sglang commit 58f9060e (Unverified)
Update int8 gemm config (#2774)
Authored Jan 07, 2025 by Ke Bao; committed by GitHub, Jan 07, 2025
Parent: bdc1acf6
Showing 2 changed files with 16 additions and 5 deletions (+16, -5):
- sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu (+14, -3)
- sgl-kernel/tests/test_int8_gemm.py (+2, -2)
sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu (+14, -3)

@@ -88,10 +88,11 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
   auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());

   auto can_implement = gemm_op.can_implement(args);
-  TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
+  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, "gemm cannot implement, error: ",
+              cutlassGetStatusString(can_implement));

   auto status = gemm_op(args, workspace.data_ptr(), stream);
-  TORCH_CHECK(status == cutlass::Status::kSuccess)
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
 }

 template <typename ElementOutput, typename ArchTag, typename InstructionShape>

@@ -144,7 +145,17 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t
                              cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                          scales_b, bias);
     }
-  } else if (m <= 64 || (m <= 128 && n < 8192)) {
+  } else if (m <= 64) {
+    if (n <= 4096) {
+      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 64, 128>,
+                             cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
+                                                                                         scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
+                             cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
+                                                                                        scales_b, bias);
+    }
+  } else if (m <= 128 && n < 8192) {
     cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
                            cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                       scales_b, bias);
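In short, the second hunk splits the old combined branch (m <= 64 || (m <= 128 && n < 8192)) so that m <= 64 now picks its tile shapes based on n. The Python snippet below is only an illustrative restatement of the branches visible in this hunk; the helper name and tuple encoding are invented for the sketch, the real dispatch is the C++ sm80_dispatch_shape above, and branches outside the hunk are not reproduced.

def pick_sm80_int8_gemm_tiles(m: int, n: int):
    """Illustrative only: mirrors the tile-shape choices visible in this hunk.

    Returns (ThreadblockShape, WarpShape). Every call shown in the hunk uses
    InstructionShape and 5 pipeline stages, so those are omitted here.
    """
    if m <= 64:
        if n <= 4096:
            return (64, 64, 128), (32, 64, 64)   # new branch added by this commit
        return (64, 128, 128), (64, 64, 64)      # new branch added by this commit
    if m <= 128 and n < 8192:
        return (64, 128, 128), (64, 64, 64)      # same shapes as before the change
    raise NotImplementedError("branches outside this hunk are not reproduced")

# Example: with m = 48 and n = 4096, the new dispatch selects the narrower
# 64x64x128 threadblock where the old combined branch used 64x128x128.
print(pick_sm80_int8_gemm_tiles(48, 4096))   # ((64, 64, 128), (32, 64, 64))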
sgl-kernel/tests/test_int8_gemm.py (+2, -2)

@@ -37,8 +37,8 @@ class TestInt8Gemm(unittest.TestCase):
         print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

     def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [16, 128, 512, 1024, 4096]
+        Ms = [1, 128, 512, 1024, 4096, 8192]
+        Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
         Ks = [512, 1024, 4096, 8192, 16384]
         bias_opts = [True, False]
         out_dtypes = [torch.float16, torch.bfloat16]
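The test change only widens the shape sweep in test_accuracy (M now reaches 8192 and N reaches 16384), so larger shapes exercise the new dispatch branches above. As a rough, self-contained sketch of the kind of accuracy check such a test performs, here is a plain-PyTorch reference with no sgl-kernel import; the quantization scheme, shapes, and names below are assumptions for illustration and are not taken from the test file.

import torch

def reference_int8_scaled_mm(a_fp, b_fp, out_dtype=torch.float16):
    # Symmetric int8 quantization with a per-row scale for A and a per-column
    # scale for B, then an integer matmul rescaled back to floating point.
    # The CUDA kernel under test fuses these steps and accumulates in int32;
    # float64 is used here only so the sketch runs anywhere and stays exact.
    scale_a = a_fp.abs().amax(dim=1, keepdim=True) / 127.0
    scale_b = b_fp.abs().amax(dim=0, keepdim=True) / 127.0
    a_q = (a_fp / scale_a).round().clamp(-128, 127).to(torch.int8)
    b_q = (b_fp / scale_b).round().clamp(-128, 127).to(torch.int8)
    acc = a_q.double() @ b_q.double()
    return (acc * scale_a.double() * scale_b.double()).to(out_dtype)

M, N, K = 128, 512, 1024          # one (M, N, K) point from the sweep above
a = torch.randn(M, K) * 0.1
b = torch.randn(K, N) * 0.1
out = reference_int8_scaled_mm(a, b)
ref = (a @ b).to(torch.float16)
print("max abs err:", (out.float() - ref.float()).abs().max().item())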