sglang · Commit 44401358 (unverified)

Fix typos and unify size(s)/stride(s) API calls (#8799)

Authored Aug 08, 2025 by triple-mu; committed via GitHub on Aug 08, 2025
Parent: 9c7e3924
Changes: 6 changed files with 34 additions and 34 deletions (+34 -34)

  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu     +5  -5
  sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu           +3  -3
  sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu   +1  -1
  sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu             +1  -1
  sgl-kernel/csrc/gemm/int8_gemm_kernel.cu            +1  -1
  sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu     +23 -23
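The change standardizes on the dimension-indexed accessors size(i) and stride(i) instead of indexing into the sizes()/strides() arrays, and fixes copy-pasted "mat_a" wording in the mat_b column-major error messages. A minimal standalone sketch of the equivalence, not part of the commit, assuming only a local libtorch installation (tensor shapes are made up for illustration):

    // equivalence_sketch.cpp -- illustrative only, not from the commit
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::Tensor mat_a = torch::randn({4, 8});       // row-major by default
      torch::Tensor mat_b = torch::randn({16, 8}).t();  // shape [8, 16], column-major layout

      // sizes()[i] and size(i) return the same value; the commit standardizes on size(i).
      TORCH_CHECK(mat_a.sizes()[1] == mat_a.size(1));
      // Likewise for strides()[i] and stride(i).
      TORCH_CHECK(mat_b.strides()[0] == mat_b.stride(0));

      // The layout checks touched by the commit, with the corrected messages:
      TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
      TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");

      std::cout << "all checks passed" << std::endl;
      return 0;
    }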
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

@@ -105,10 +105,10 @@ typename T::Fmha::Arguments args_from_options(
   hw_info.device_id = q_nope.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-  int batches = q_nope.sizes()[0];
-  int page_count_per_seq = page_table.sizes()[1];
-  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
-  int page_size = kv_c_and_k_pe_cache.sizes()[1];
+  int batches = q_nope.size(0);
+  int page_count_per_seq = page_table.size(1);
+  int page_count_total = kv_c_and_k_pe_cache.size(0);
+  int page_size = kv_c_and_k_pe_cache.size(1);
   int max_seq_len = page_size * page_count_per_seq;
   using TileShapeH = typename T::TileShapeH;
   using TileShapeD = typename T::TileShapeD;

@@ -220,7 +220,7 @@ void cutlass_mla_decode(
   auto in_dtype = q_nope.dtype();
   at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
-  const int page_size = kv_c_and_k_pe_cache.sizes()[1];
+  const int page_size = kv_c_and_k_pe_cache.size(1);
   // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
   // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu

@@ -640,9 +640,9 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch:
   TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]")
   TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]")
-  TORCH_CHECK(mat_a.strides()[1] == 1);   // Row-major
-  TORCH_CHECK(output.strides()[1] == 1);  // Row-major
-  TORCH_CHECK(mat_b.strides()[0] == 1);   // Column-major
+  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");     // Row-major
+  TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor");   // Row-major
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");  // Column-major
   auto const data_type = mat_a.scalar_type();
   TORCH_CHECK(
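Besides switching to stride(i), the three rewritten checks above now carry explicit failure messages. A minimal sketch of the difference in error output, not part of the commit and assuming a local libtorch installation (the shapes are made up):

    // torch_check_message_sketch.cpp -- illustrative only, not from the commit
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::Tensor mat_b = torch::randn({8, 16});  // row-major, so stride(0) != 1

      try {
        TORCH_CHECK(mat_b.stride(0) == 1);  // bare form: the error only restates the condition
      } catch (const c10::Error& e) {
        std::cout << e.what() << "\n";
      }

      try {
        TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");  // message names the fix
      } catch (const c10::Error& e) {
        std::cout << e.what() << "\n";
      }
      return 0;
    }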
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu

@@ -353,7 +353,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
   TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
   TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
-  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
   TORCH_CHECK(
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu

@@ -1080,7 +1080,7 @@ torch::Tensor fp8_scaled_mm(
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
   TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
   TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
-  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
   TORCH_CHECK(
sgl-kernel/csrc/gemm/int8_gemm_kernel.cu

@@ -672,7 +672,7 @@ torch::Tensor int8_scaled_mm(
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
   TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
   TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
-  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
   TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
   TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu

@@ -273,20 +273,20 @@ void cutlass_scaled_fp4_mm_sm100a(
   TORCH_CHECK(A.dim() == 2, "a must be a matrix");
   TORCH_CHECK(B.dim() == 2, "b must be a matrix");
   TORCH_CHECK(
-      A.sizes()[1] == B.sizes()[1],
+      A.size(1) == B.size(1),
       "a and b shapes cannot be multiplied (",
-      A.sizes()[0],
+      A.size(0),
       "x",
-      A.sizes()[1],
+      A.size(1),
       " and ",
-      B.sizes()[0],
+      B.size(0),
       "x",
-      B.sizes()[1],
+      B.size(1),
       ")");

-  auto const m = A.sizes()[0];
-  auto const n = B.sizes()[0];
-  auto const k = A.sizes()[1] * 2;
+  auto const m = A.size(0);
+  auto const n = B.size(0);
+  auto const k = A.size(1) * 2;

   constexpr int alignment = 32;
   TORCH_CHECK(

@@ -294,9 +294,9 @@ void cutlass_scaled_fp4_mm_sm100a(
       "Expected k to be divisible by ",
       alignment,
       ", but got a shape: (",
-      A.sizes()[0],
+      A.size(0),
       "x",
-      A.sizes()[1],
+      A.size(1),
       "), k: ",
       k,
       ".");

@@ -305,9 +305,9 @@ void cutlass_scaled_fp4_mm_sm100a(
       "Expected n to be divisible by ",
       alignment,
       ", but got b shape: (",
-      B.sizes()[0],
+      B.size(0),
       "x",
-      B.sizes()[1],
+      B.size(1),
       ").");
   auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
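As a side note on the round_up helper visible in the context above: it rounds x up to the nearest multiple of y using integer arithmetic. A tiny standalone check, not part of the commit (the multiples actually used for rounded_m/rounded_n/rounded_k are not shown in this hunk):

    // round_up_sketch.cpp -- illustrative only, not from the commit
    #include <cassert>

    int main() {
      auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
      assert(round_up(300, 128) == 384);  // rounds up to the next multiple
      assert(round_up(256, 128) == 256);  // exact multiples are unchanged
      return 0;
    }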
@@ -320,37 +320,37 @@ void cutlass_scaled_fp4_mm_sm100a(
   TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
   TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
   TORCH_CHECK(
-      A_sf.sizes()[1] == B_sf.sizes()[1],
+      A_sf.size(1) == B_sf.size(1),
       "scale_a and scale_b shapes cannot be multiplied (",
-      A_sf.sizes()[0],
+      A_sf.size(0),
       "x",
-      A_sf.sizes()[1],
+      A_sf.size(1),
       " and ",
-      B_sf.sizes()[0],
+      B_sf.size(0),
       "x",
-      B_sf.sizes()[1],
+      B_sf.size(1),
       ")");
   TORCH_CHECK(
-      A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
+      A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
       "scale_a must be padded and swizzled to a shape (",
       rounded_m,
       "x",
       rounded_k,
       "), but got a shape (",
-      A_sf.sizes()[0],
+      A_sf.size(0),
       "x",
-      A_sf.sizes()[1],
+      A_sf.size(1),
       ")");
   TORCH_CHECK(
-      B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
+      B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
       "scale_b must be padded and swizzled to a shape (",
       rounded_n,
       "x",
       rounded_k,
       "), but got a shape (",
-      B_sf.sizes()[0],
+      B_sf.size(0),
       "x",
-      B_sf.sizes()[1],
+      B_sf.size(1),
       ")");

   auto out_dtype = D.dtype();