Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3125d799
Unverified
Commit
3125d799
authored
Oct 18, 2025
by
Isotr0py
Committed by
GitHub
Oct 17, 2025
Browse files
[Chore] Remove unused `PolyNorm` layer (#27110)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
e33ee23e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1 addition
and
524 deletions
+1
-524
benchmarks/kernels/benchmark_polynorm.py
benchmarks/kernels/benchmark_polynorm.py
+0
-155
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+0
-252
csrc/ops.h
csrc/ops.h
+0
-3
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-6
tests/kernels/core/test_layernorm.py
tests/kernels/core/test_layernorm.py
+1
-33
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-12
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+0
-63
No files found.
benchmarks/kernels/benchmark_polynorm.py
deleted
100644 → 0
View file @
e33ee23e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
torch
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Reference PolyNorm written in plain PyTorch.

    Computes ``weight[0] * norm(x**3) + weight[1] * norm(x**2)
    + weight[2] * norm(x) + bias``, where ``norm(t)`` divides ``t`` by
    ``sqrt(mean(t**2, dim=-1) + eps)``. The arithmetic is done in float32.

    Args:
        x: Input tensor; normalization runs over its last dimension.
        weight: Coefficients, read as ``weight[0]``, ``weight[1]``, ``weight[2]``.
        bias: Additive term broadcast over the result.
        eps: Value added to the mean square before taking the square root.

    Returns:
        A tensor with the shape of ``x`` and the dtype of ``weight``.
    """
    orig_shape = x.shape
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # transposed batch, for which view() raises a RuntimeError.
    rows = x.reshape(-1, orig_shape[-1]).float()

    def norm(t: torch.Tensor) -> torch.Tensor:
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    out = (
        weight[0] * norm(rows**3)
        + weight[1] * norm(rows**2)
        + weight[2] * norm(rows)
        + bias
    )
    return out.to(weight.dtype).reshape(orig_shape)
def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    """Run PolyNorm through the vLLM custom op ``vllm_ops.poly_norm``.

    The input is viewed as 2-D ``(-1, last_dim)``, the op writes into a freshly
    allocated buffer of the same shape, and the result is viewed back to the
    shape of ``x``.
    """
    input_shape = x.shape
    flat = x.view(-1, input_shape[-1])
    result = torch.empty_like(flat)
    # The op fills `result` in place; it has no return value that is used here.
    vllm_ops.poly_norm(result, flat, weight, bias, eps)
    return result.view(input_shape)
def calculate_diff(batch_size, seq_len, hidden_dim):
    """Print whether the naive and vLLM PolyNorm outputs agree.

    Builds a random bfloat16 CUDA input of shape
    ``(batch_size, seq_len, hidden_dim)`` with all-ones weight and bias, runs
    both implementations, and compares them with ``atol=rtol=1e-2``.
    """
    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    x = torch.randn(batch_size, seq_len, hidden_dim, **tensor_opts)
    weight = torch.ones(3, **tensor_opts)
    bias = torch.ones(1, **tensor_opts)

    reference = polynorm_naive(x, weight, bias)
    candidate = polynorm_vllm(x, weight, bias)

    matches = torch.allclose(reference, candidate, atol=1e-2, rtol=1e-2)
    print("✅ All implementations match" if matches else "❌ Implementations differ")
# Benchmark sweep: batch sizes 1, 4, 16, 64 and sequence lengths 64 .. 1024.
batch_size_range = [1 << exp for exp in range(0, 7, 2)]
seq_length_range = [1 << exp for exp in range(6, 11)]
dim_range = [2048, 4096]
# One (dim, batch_size, seq_len) tuple per benchmark point.
configs = list(
    itertools.product(
        dim_range,
        batch_size_range,
        seq_length_range,
    )
)
def get_benchmark():
    """Build the Triton perf-report benchmark comparing both PolyNorm paths.

    Returns:
        The decorated ``benchmark`` callable. Calling ``.run(...)`` on it
        sweeps every ``(dim, batch_size, seq_len)`` entry of ``configs`` for
        the ``"naive"`` and ``"vllm"`` providers.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        """Time one provider on a random bf16 input; returns microseconds."""
        dtype = torch.bfloat16
        # The benchmarked tensor's last dimension is four times `dim`.
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        # do_bench returns one value per quantile, in this order:
        # median, 20th percentile (fast end), 80th percentile (slow end).
        quantiles = [0.5, 0.2, 0.8]
        impl = polynorm_naive if provider == "naive" else polynorm_vllm
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: impl(x, weight, bias),
            quantiles=quantiles,
        )

        # perf_report unpacks the result as (mean, min, max). For a latency
        # metric the low quantile is the minimum, so the order must be
        # (ms, min_ms, max_ms); the previous (ms, max_ms, min_ms) swapped the
        # band bounds. Multiply by 1000 to convert ms -> us (see ylabel).
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Each entry: (flag, value type, default, help text).
    option_specs = [
        ("--batch-size", int, 4, "Batch size"),
        ("--seq-len", int, 128, "Sequence length"),
        ("--hidden-dim", int, 8192, "Intermediate size of MLP"),
        (
            "--save-path",
            str,
            "./configs/polnorm/",
            "Path to save polnorm benchmark results",
        ),
    ]
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(
            flag,
            type=value_type,
            default=default_value,
            help=help_text,
        )

    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    # Run performance benchmark
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=args.save_path)
csrc/layernorm_kernels.cu
View file @
3125d799
...
@@ -148,211 +148,6 @@ fused_add_rms_norm_kernel(
...
@@ -148,211 +148,6 @@ fused_add_rms_norm_kernel(
}
}
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  // Returns the tuple (sum of x^2, sum of x^4, sum of x^6) over all `width`
  // elements held in `data`, accumulated in float. Elements are processed
  // two at a time through Converter::convert.
  __device__ auto sum_pows() const {
    float acc_p2 = 0.0f;
    float acc_p4 = 0.0f;
    float acc_p6 = 0.0f;

#pragma unroll
    for (int i = 0; i < width; i += 2) {
      const float2 lanes = Converter::convert(T2{data[i], data[i + 1]});

      const float lo_p2 = lanes.x * lanes.x;
      const float lo_p4 = lo_p2 * lo_p2;
      const float lo_p6 = lo_p4 * lo_p2;

      const float hi_p2 = lanes.y * lanes.y;
      const float hi_p4 = hi_p2 * hi_p2;
      const float hi_p6 = hi_p4 * hi_p2;

      acc_p2 += lo_p2 + hi_p2;
      acc_p4 += lo_p4 + hi_p4;
      acc_p6 += lo_p6 + hi_p6;
    }

    return std::make_tuple(acc_p2, acc_p4, acc_p6);
  }

  // Overwrites every element v of `data` with
  //   w2_inv_std * v + w1_inv_std2 * v^2 + w0_inv_std3 * v^3 + bias,
  // evaluated in float and converted back through Converter::convert.
  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3, const float bias) {
#pragma unroll
    for (int i = 0; i < width; i += 2) {
      float2 lanes = Converter::convert(T2{data[i], data[i + 1]});

      // Powers are taken from the original values before either lane is
      // overwritten.
      const float lo_sq = lanes.x * lanes.x;
      const float lo_cu = lo_sq * lanes.x;
      const float hi_sq = lanes.y * lanes.y;
      const float hi_cu = hi_sq * lanes.y;

      lanes.x =
          w2_inv_std * lanes.x + w1_inv_std2 * lo_sq + w0_inv_std3 * lo_cu + bias;
      lanes.y =
          w2_inv_std * lanes.y + w1_inv_std2 * hi_sq + w0_inv_std3 * hi_cu + bias;

      const auto packed = Converter::convert(lanes);
      data[i] = packed.x;
      data[i + 1] = packed.y;
    }
  }
};
// Vectorized poly norm kernel, selected when width > 0 and a converter exists
// for scalar_t. Each block handles the row selected by blockIdx.x:
//   1. every thread accumulates sums of x^2, x^4 and x^6 over its vectors,
//   2. the per-thread sums are reduced across the block,
//   3. thread 0 turns them into three scale factors kept in shared memory,
//   4. every thread rewrites its vectors with the polynomial and stores to out.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);

  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
  const int vec_hidden_size = hidden_size / width;

  // Row offset in vector units, computed in 64 bits: the previous
  // `int id = blockIdx.x * vec_hidden_size + idx` was 32-bit arithmetic and
  // could wrap for very large inputs.
  const long long row_base =
      static_cast<long long>(blockIdx.x) * vec_hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    _f16VecPN<scalar_t, width> temp = input_v[row_base + idx];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise sum so all three accumulators are reduced in one pass.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it publishes the scale factors
  // through shared memory for the rest of the block.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  auto* __restrict__ out_v =
      reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    _f16VecPN<scalar_t, width> temp = input_v[row_base + idx];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[row_base + idx] = temp;
  }
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
*/
// Scalar fallback, selected when width == 0 or no converter exists for
// scalar_t. Each block handles the row selected by blockIdx.x and computes
//   out = x * w2 * rsqrt(mean(x^2) + eps)
//       + x^2 * w1 * rsqrt(mean(x^4) + eps)
//       + x^3 * w0 * rsqrt(mean(x^6) + eps) + bias
// with the means taken over the hidden_size elements of that row.
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Row offset in elements, computed in 64 bits: the previous
  // `blockIdx.x * hidden_size + idx` was 32-bit arithmetic and could wrap
  // once the tensor holds more than 2^32 elements.
  const long long row_base = static_cast<long long>(blockIdx.x) * hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_base + idx];
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise sum so all three accumulators are reduced in one pass.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it publishes the scale factors
  // through shared memory for the rest of the block.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_base + idx];
    float x2 = x * x;
    float x3 = x2 * x;

    out[row_base + idx] = (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 +
                                     x3 * s_w0_inv_std3 + s_bias);
  }
}
}
// namespace vllm
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
...
@@ -444,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
...
@@ -444,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
}
// Launches vllm::poly_norm_kernel<scalar_t, width> for the floating dtype of
// `input`. Expects `grid`, `block`, `stream`, `out`, `input`, `weight`,
// `bias`, `epsilon` and `hidden_size` to be in scope at the expansion site.
#define LAUNCH_FUSED_POLY_NORM(width)                                         \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
    vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>(      \
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),                 \
        weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon,      \
        hidden_size);                                                         \
  });

// Host entry point for polynomial normalization: writes the result for
// `input` into `out`, launching one block per row of hidden_size elements.
//
// Raises (via TORCH_CHECK) when out/input/weight are not contiguous, when
// weight holds fewer than 3 or bias fewer than 1 element, when out and input
// differ in element count, or when out aliases input. An empty input is a
// no-op.
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous(), "poly_norm: out must be contiguous");
  TORCH_CHECK(input.is_contiguous(), "poly_norm: input must be contiguous");
  // The kernels index weight[0..2] and bias[0] through raw pointers, so the
  // buffers must be dense and large enough to avoid out-of-bounds reads.
  TORCH_CHECK(weight.is_contiguous(), "poly_norm: weight must be contiguous");
  TORCH_CHECK(weight.numel() >= 3,
              "poly_norm: weight must hold at least 3 elements, got ",
              weight.numel());
  TORCH_CHECK(bias.numel() >= 1,
              "poly_norm: bias must hold at least 1 element, got ",
              bias.numel());
  // The kernels write one output element per input element.
  TORCH_CHECK(out.numel() == input.numel(),
              "poly_norm: out and input must have the same number of "
              "elements, got ",
              out.numel(), " and ", input.numel());

  // Nothing to compute. Returning here also avoids dividing by a zero
  // hidden_size and launching a zero-sized grid below.
  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(out.data_ptr() != input.data_ptr(),
              "poly_norm: out must not alias input");

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
  if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_POLY_NORM(0);
  }
}
csrc/ops.h
View file @
3125d799
...
@@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
...
@@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
poly_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
bias
,
double
epsilon
);
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
output_mask
,
...
...
csrc/torch_bindings.cpp
View file @
3125d799
...
@@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
// Polynomial Normalization.
ops
.
def
(
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
"epsilon) -> ()"
);
ops
.
impl
(
"poly_norm"
,
torch
::
kCUDA
,
&
poly_norm
);
// Apply repetition penalties to logits in-place
// Apply repetition penalties to logits in-place
ops
.
def
(
ops
.
def
(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
...
...
tests/kernels/core/test_layernorm.py
View file @
3125d799
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
tests.kernels.quant_utils
import
FP8_DTYPE
from
tests.kernels.quant_utils
import
FP8_DTYPE
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
PolyNorm
,
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -70,38 +70,6 @@ def test_rms_norm(
...
@@ -70,38 +70,6 @@ def test_rms_norm(
)
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_poly_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check PolyNorm's dispatched forward against its native forward."""
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    layer = PolyNorm().to(dtype=dtype)
    # Randomize the parameters around 1.0 so every term contributes.
    for param in (layer.weight, layer.bias):
        param.data.normal_(mean=1.0, std=0.1)

    # Input is scaled down by 1 / (2 * hidden_size), in place.
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= 1 / (2 * hidden_size)

    expected = layer.forward_native(x)
    actual = layer(x)
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

    op_args = (
        actual,
        x,
        layer.weight.data,
        layer.bias.data,
        layer.variance_epsilon,
    )
    opcheck(torch.ops._C.poly_norm, op_args)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
ADD_RESIDUAL
)
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
ADD_RESIDUAL
)
...
...
vllm/_custom_ops.py
View file @
3125d799
...
@@ -339,18 +339,6 @@ def fused_add_rms_norm(
...
@@ -339,18 +339,6 @@ def fused_add_rms_norm(
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
def poly_norm(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    epsilon: float,
) -> None:
    """Call the compiled ``_C.poly_norm`` op, which writes its result to ``out``.

    ``input`` is made contiguous before the call; the other arguments are
    forwarded unchanged.
    """
    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
    torch.ops._C.poly_norm(out, input.contiguous(), weight, bias, epsilon)
def
apply_repetition_penalties_torch
(
def
apply_repetition_penalties_torch
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
prompt_mask
:
torch
.
Tensor
,
prompt_mask
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/layernorm.py
View file @
3125d799
...
@@ -58,22 +58,6 @@ def fused_add_rms_norm(
...
@@ -58,22 +58,6 @@ def fused_add_rms_norm(
return
x
,
residual
return
x
,
residual
def poly_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    variance_epsilon: float,
) -> torch.Tensor:
    """Apply the custom ``poly_norm`` op to ``x`` and return a new tensor.

    The output buffer is allocated with ``torch.empty_like(x)`` and filled by
    ``vllm._custom_ops.poly_norm``.
    """
    # Imported inside the function, as in the original.
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result
def
rocm_aiter_rms_norm_impl
(
def
rocm_aiter_rms_norm_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -385,53 +369,6 @@ class GemmaRMSNorm(CustomOp):
...
@@ -385,53 +369,6 @@ class GemmaRMSNorm(CustomOp):
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization layer.

    Maps x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b,
    with w_n the learned weights and b the learned bias.
    See https://arxiv.org/html/2411.03884v1
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Three coefficients initialised to 1/3 each, and a zero bias.
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        """Divide ``x`` by the root of its last-dim mean square plus epsilon."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x / torch.sqrt(mean_square + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Pure-PyTorch forward pass, equivalent to forward().

        Computed in float32 and cast back to the dtype of ``x``.
        See https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """
        x32 = x.to(torch.float32)
        cubic = self.weight[0] * self._norm(x32**3)
        quadratic = self.weight[1] * self._norm(x32**2)
        linear = self.weight[2] * self._norm(x32)
        result = cubic + quadratic + linear + self.bias
        return result.to(x.dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass through the module-level ``poly_norm`` function."""
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
class
LayerNorm
(
nn
.
Module
):
class
LayerNorm
(
nn
.
Module
):
"""
"""
Layer Normalization.
Layer Normalization.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment