Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c3c0f82
Unverified
Commit
6c3c0f82
authored
Nov 11, 2025
by
Xin Yang
Committed by
GitHub
Nov 11, 2025
Browse files
[Kernel] Optimize rms_norm kernel (#27931)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
684f2545
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
86 additions
and
25 deletions
+86
-25
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+29
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+28
-11
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+29
-14
No files found.
csrc/dispatch_utils.h
View file @
6c3c0f82
...
@@ -88,3 +88,32 @@
...
@@ -88,3 +88,32 @@
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
// Dispatches on a runtime vector width.
//
// Expands to a switch over VEC_SIZE. The matching branch declares
// `constexpr int vec_size` with the matched width (16, 8, 4 or 2) and then
// invokes the callable passed as the remaining macro argument(s), so the
// callable's body can use `vec_size` as a compile-time constant (for example
// as a template argument). Any other value of VEC_SIZE takes the default
// branch, which invokes the callable with vec_size = 1.
//
// The expansion is wrapped in do { ... } while (0) so the macro behaves as a
// single statement: `VLLM_DISPATCH_VEC_SIZE(n, fn);` stays well-formed even
// when it is the unbraced body of an if/else. Callers must terminate the
// invocation with a semicolon.
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
  do {                                        \
    switch (VEC_SIZE) {                       \
      case 16: {                              \
        constexpr int vec_size = 16;          \
        __VA_ARGS__();                        \
        break;                                \
      }                                       \
      case 8: {                               \
        constexpr int vec_size = 8;           \
        __VA_ARGS__();                        \
        break;                                \
      }                                       \
      case 4: {                               \
        constexpr int vec_size = 4;           \
        __VA_ARGS__();                        \
        break;                                \
      }                                       \
      case 2: {                               \
        constexpr int vec_size = 2;           \
        __VA_ARGS__();                        \
        break;                                \
      }                                       \
      default: {                              \
        constexpr int vec_size = 1;           \
        __VA_ARGS__();                        \
        break;                                \
      }                                       \
    }                                         \
  } while (0)
csrc/layernorm_kernels.cu
View file @
6c3c0f82
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
int
VEC_SIZE
>
__global__
void
rms_norm_kernel
(
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
...
@@ -21,7 +21,6 @@ __global__ void rms_norm_kernel(
...
@@ -21,7 +21,6 @@ __global__ void rms_norm_kernel(
float
variance
=
0.0
f
;
float
variance
=
0.0
f
;
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
constexpr
int
VEC_SIZE
=
8
;
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
...
@@ -45,10 +44,20 @@ __global__ void rms_norm_kernel(
...
@@ -45,10 +44,20 @@ __global__ void rms_norm_kernel(
}
}
__syncthreads
();
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
*
out_row
=
out
+
blockIdx
.
x
*
hidden_size
;
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
auto
*
v_out
=
reinterpret_cast
<
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
out_row
);
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
/
VEC_SIZE
;
i
+=
blockDim
.
x
)
{
vec_n_t
<
scalar_t
,
VEC_SIZE
>
dst
;
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
i
];
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
dst
.
val
[
j
]
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
}
v_out
[
i
]
=
dst
;
}
}
}
}
...
@@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
...
@@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
int
num_tokens
=
input_view
.
numel
()
/
hidden_size
;
int
num_tokens
=
input_view
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input_view
.
stride
(
-
2
);
int64_t
input_stride
=
input_view
.
stride
(
-
2
);
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input_view
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input_view
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input_view
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
input_view
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
const
int
calculated_vec_size
=
out
.
data_ptr
<
scalar_t
>
(),
input_view
.
data_ptr
<
scalar_t
>
(),
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
const
int
block_size
=
hidden_size
);
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
dim3
block
(
block_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
,
vec_size
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input_view
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
}
}
...
...
csrc/layernorm_quant_kernels.cu
View file @
6c3c0f82
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
,
typename
fp8_type
>
template
<
typename
scalar_t
,
typename
fp8_type
,
int
VEC_SIZE
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
__global__
void
rms_norm_static_fp8_quant_kernel
(
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
...
@@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel(
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
const
scalar_t
*
input_row
=
input
+
blockIdx
.
x
*
input_stride
;
constexpr
int
VEC_SIZE
=
8
;
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
...
@@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -58,11 +57,18 @@ __global__ void rms_norm_static_fp8_quant_kernel(
// invert scale to avoid division
// invert scale to avoid division
float
const
scale_inv
=
1.0
f
/
*
scale
;
float
const
scale_inv
=
1.0
f
/
*
scale
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
/
VEC_SIZE
;
idx
+=
blockDim
.
x
)
{
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
idx
];
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
idx
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
*
VEC_SIZE
+
j
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
}
}
}
...
@@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
int
input_stride
=
input
.
stride
(
-
2
);
int
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
const
int
calculated_vec_size
=
<<<
grid
,
block
,
0
,
stream
>>>
(
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
const
int
block_size
=
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
dim3
block
(
block_size
);
hidden_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
,
vec_size
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment