Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e2b85cf8
Unverified
Commit
e2b85cf8
authored
Jun 16, 2024
by
Cody Yu
Committed by
GitHub
Jun 17, 2024
Browse files
Fix w8a8 benchmark and add Llama-3-8B (#5562)
parent
845a3f26
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
8 deletions
+19
-8
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+13
-8
benchmarks/cutlass_benchmarks/weight_shapes.py
benchmarks/cutlass_benchmarks/weight_shapes.py
+6
-0
No files found.
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
e2b85cf8
...
...
@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
# impl
def
pytorch_
i8
_impl
(
a
:
torch
.
tensor
,
b
:
torch
.
tensor
,
scale_a
:
torch
.
tensor
,
def
pytorch_
mm
_impl
(
a
:
torch
.
tensor
,
b
:
torch
.
tensor
,
scale_a
:
torch
.
tensor
,
scale_b
:
torch
.
tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
tensor
:
return
torch
.
mm
(
a
,
b
)
...
...
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_
i8
_impl
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_
mm
_impl
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
))
# cutlass impl
...
...
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers
=
[]
# pytorch impl w. bf16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
))
# pytorch impl: bf16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
...
...
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
.
to
(
device
=
"cpu"
),
scale_b
.
to
(
device
=
"cpu"
),
torch
.
bfloat16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_bf16_scaled_mm"
))
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_bf16_scaled_mm"
))
# cutlass impl: fp16 output
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
.
to
(
device
=
"cpu"
),
scale_b
.
to
(
device
=
"cpu"
),
torch
.
float16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_fp16_scaled_mm"
))
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_fp16_scaled_mm"
))
return
timers
...
...
benchmarks/cutlass_benchmarks/weight_shapes.py
View file @
e2b85cf8
...
...
@@ -22,6 +22,12 @@ WEIGHT_SHAPES = {
([
4096
,
22016
],
1
),
([
11008
,
4096
],
0
),
],
"meta-llama/Llama-3-8b"
:
[
([
4096
,
6144
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
28672
],
1
),
([
14336
,
4096
],
0
),
],
"meta-llama/Llama-2-13b-hf"
:
[
([
5120
,
15360
],
1
),
([
5120
,
5120
],
0
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment