Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
572cd426
Commit
572cd426
authored
Dec 15, 2025
by
zhuwenwen
Browse files
restore the initial fp8 related implementation
parent
c441dda9
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
143 additions
and
145 deletions
+143
-145
CMakeLists.txt
CMakeLists.txt
+3
-3
cmake/utils.cmake
cmake/utils.cmake
+1
-1
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+0
-2
csrc/ops.h
csrc/ops.h
+17
-17
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+1
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+17
-17
vllm/_custom_ops.py
vllm/_custom_ops.py
+60
-60
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+9
-9
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+25
-25
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+10
-10
No files found.
CMakeLists.txt
View file @
572cd426
...
...
@@ -255,15 +255,15 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu"
#
"csrc/layernorm_quant_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
#
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
#
"csrc/quantization/activation_kernels.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"
...
...
cmake/utils.cmake
View file @
572cd426
...
...
@@ -123,7 +123,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
#
"-DENABLE_FP8"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-Werror=unused-variable"
...
...
csrc/layernorm_quant_kernels.cu
View file @
572cd426
...
...
@@ -6,9 +6,7 @@
*/
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include <torch/cuda.h>
...
...
csrc/ops.h
View file @
572cd426
...
...
@@ -221,15 +221,15 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
//
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
//
torch::Tensor& weight, torch::Tensor& scale,
//
double epsilon);
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
double
epsilon
);
//
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
//
torch::Tensor& input,
//
torch::Tensor& residual,
//
torch::Tensor& weight,
//
torch::Tensor& scale, double epsilon);
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
double
epsilon
);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
@@ -258,8 +258,8 @@ void rotary_embedding_tgi(
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
//
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
//
torch::Tensor& scale);
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
@@ -443,15 +443,15 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//
torch::Tensor const& scale);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
//
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//
torch::Tensor& scale);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
);
//
void dynamic_per_token_scaled_fp8_quant(
//
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//
std::optional<torch::Tensor> const& scale_ub);
void
dynamic_per_token_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
572cd426
...
...
@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
//
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
namespace
vllm
{
...
...
csrc/torch_bindings.cpp
View file @
572cd426
...
...
@@ -258,8 +258,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"silu_and_mul"
,
torch
::
kCUDA
,
&
silu_and_mul
);
// ops.def(
//
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
//
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"
);
ops
.
impl
(
"silu_and_mul_quant"
,
torch
::
kCUDA
,
&
silu_and_mul_quant
);
ops
.
def
(
"mul_and_silu(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"mul_and_silu"
,
torch
::
kCUDA
,
&
mul_and_silu
);
...
...
@@ -737,25 +737,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
//
ops.def(
//
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//
"()");
//
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
ops
.
def
(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"
);
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
//
ops.def(
//
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//
"-> "
//
"()");
//
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
ops
.
def
(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()"
);
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
//
ops.def(
//
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
//
"Tensor! scale, Tensor? scale_ub) -> "
//
"()");
//
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
//
&dynamic_per_token_scaled_fp8_quant);
ops
.
def
(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()"
);
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_per_token_scaled_fp8_quant
);
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
...
...
vllm/_custom_ops.py
View file @
572cd426
...
...
@@ -1692,66 +1692,66 @@ def scaled_fp4_experts_quant(
# fp8
#
def scaled_fp8_quant(
#
input: torch.Tensor,
#
scale: Optional[torch.Tensor] = None,
#
num_token_padding: Optional[int] = None,
#
scale_ub: Optional[torch.Tensor] = None,
#
use_per_token_if_dynamic: bool = False,
#
output: Optional[torch.Tensor] = None,
#
) -> tuple[torch.Tensor, torch.Tensor]:
#
"""
#
Quantize input tensor to FP8 and return quantized tensor and scale.
#
This function supports both static and dynamic quantization: If you
#
provide the scale, it will use static scaling and if you omit it,
#
the scale will be determined dynamically. The function also allows
#
optional padding of the output tensors for downstream kernels that
#
will benefit from padding.
#
Args:
#
input: The input tensor to be quantized to FP8
#
scale: Optional scaling factor for the FP8 quantization
#
scale_ub: Optional upper bound for scaling factor in dynamic
#
per token case
#
num_token_padding: If specified, pad the first dimension
#
of the output to at least this value.
#
use_per_token_if_dynamic: Whether to do per_tensor or per_token
#
in the dynamic quantization case.
#
Returns:
#
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
#
scaling factor.
#
"""
#
# This code assumes batch_dim and num_tokens are flattened
#
assert (input.ndim == 2)
#
shape: Union[tuple[int, int], torch.Size] = input.shape
#
# For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
#
out_dtype: torch.dtype = current_platform.fp8_dtype()
#
if num_token_padding:
#
shape = (max(num_token_padding, input.shape[0]), shape[1])
#
if output is None:
#
output = torch.empty(shape, device=input.device, dtype=out_dtype)
#
else:
#
assert num_token_padding is None, \
#
"padding not supported if output passed in"
#
assert output.dtype == out_dtype
#
if scale is None:
#
if use_per_token_if_dynamic:
#
scale = torch.empty((shape[0], 1),
#
device=input.device,
#
dtype=torch.float32)
#
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
#
output, input.contiguous(), scale, scale_ub)
#
else:
#
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
#
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
#
else:
#
assert scale.numel() == 1, f"{scale.shape}"
#
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
#
return output, scale
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
    output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 2-D tensor to FP8 and return the quantized tensor and scale.

    Static quantization is used when ``scale`` is supplied; otherwise the
    scale is computed dynamically, either per tensor or per token. The
    output can optionally be padded along its first dimension for
    downstream kernels that benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8. Must be 2-D
            (batch and token dimensions already flattened).
        scale: Optional scaling factor for static quantization. When
            given it must contain exactly one element.
        num_token_padding: If specified, pad the first dimension of the
            output to at least this value.
        scale_ub: Optional upper bound for the scaling factor; only
            forwarded in the dynamic per-token case.
        use_per_token_if_dynamic: Selects per-token (True) or per-tensor
            (False) scaling in the dynamic quantization case.
        output: Optional preallocated output tensor. Its dtype must equal
            the platform FP8 dtype, and it cannot be combined with
            ``num_token_padding``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            the scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)

    out_shape: Union[tuple[int, int], torch.Size] = input.shape
    # The FP8 dtype is platform dependent (the original note singles out
    # ROCm on MI300 as using an "fnuz" FP8 variant), so always ask the
    # platform instead of hard-coding one.
    fp8_dtype: torch.dtype = current_platform.fp8_dtype()
    if num_token_padding:
        out_shape = (max(num_token_padding, input.shape[0]), out_shape[1])

    if output is not None:
        # A caller-provided buffer is used as-is, so padding is rejected.
        assert num_token_padding is None, \
            "padding not supported if output passed in"
        assert output.dtype == fp8_dtype
    else:
        output = torch.empty(out_shape,
                             device=input.device,
                             dtype=fp8_dtype)

    # Static path: the caller supplied the (single-element) scale.
    if scale is not None:
        assert scale.numel() == 1, f"{scale.shape}"
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
        return output, scale

    # Dynamic paths: the kernel writes the scale it computed into `scale`.
    if use_per_token_if_dynamic:
        # One scale per (possibly padded) output row.
        scale = torch.empty((out_shape[0], 1),
                            device=input.device,
                            dtype=torch.float32)
        torch.ops._C.dynamic_per_token_scaled_fp8_quant(
            output, input.contiguous(), scale, scale_ub)
    else:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)

    return output, scale
# gptq allspark
...
...
vllm/compilation/fix_functionalization.py
View file @
572cd426
...
...
@@ -62,9 +62,9 @@ class FixFunctionalizationPass(VllmInductorPass):
elif
at_target
==
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
:
mutated_args
=
{
1
:
'input'
,
2
:
'residual'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
#
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
#
mutated_args = {1: 'result', 2: 'residual'}
#
self.defunctionalize(graph, node, mutated_args)
elif
at_target
==
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
:
# noqa: E501
mutated_args
=
{
1
:
'result'
,
2
:
'residual'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
elif
at_target
==
torch
.
ops
.
_C
.
rms_norm_dynamic_per_token_quant
.
default
:
# noqa: E501
mutated_args
=
{
1
:
'result'
,
2
:
'scale'
,
3
:
'residual'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
...
...
@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass):
node
,
mutated_args
,
args
=
(
'result'
,
'input'
))
#
elif at_target == torch.ops._C.silu_and_mul_quant.default:
#
mutated_args = {1: 'result'}
#
self.defunctionalize(graph,
#
node,
#
mutated_args,
#
args=('result', 'input', 'scale'))
elif
at_target
==
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
:
mutated_args
=
{
1
:
'result'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
,
args
=
(
'result'
,
'input'
,
'scale'
))
else
:
continue
# skip the count
...
...
vllm/compilation/fusion.py
View file @
572cd426
...
...
@@ -82,17 +82,17 @@ class QuantKey(NamedTuple):
f
"
{
'a'
if
not
self
.
symmetric
else
''
}
symmetric)"
)
#
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
#
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
#
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
# FP8 quantization keys. The positional QuantKey arguments are passed
# exactly as before; judging by the constant names the second one selects
# static vs. dynamic scaling and the last one marks symmetric quantization
# -- TODO confirm against the QuantKey definition.
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR,
                                True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)

# Maps each quantization key to the custom op overload that implements it.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: (
        torch.ops._C.static_scaled_fp8_quant.default),  # noqa: E501
    kFp8DynamicTensorSym: (
        torch.ops._C.dynamic_scaled_fp8_quant.default),  # noqa: E501
    kFp8DynamicTokenSym: (
        torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
    ),  # noqa: E501
}
...
...
@@ -111,14 +111,14 @@ class FusedRMSQuantKey(NamedTuple):
# Maps a (quantization key, flag) pair to the fused RMSNorm + quant op
# overload. The boolean presumably indicates whether the residual add is
# fused in -- TODO confirm against FusedRMSQuantKey. Both dynamic per-token
# entries intentionally resolve to the same op overload.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    FusedRMSQuantKey(kFp8StaticTensorSym, False): (
        torch.ops._C.rms_norm_static_fp8_quant.default),  # noqa: E501
    FusedRMSQuantKey(kFp8StaticTensorSym, True): (
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    ),  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, False): (
        torch.ops._C.rms_norm_dynamic_per_token_quant.default
    ),  # noqa: E501
    FusedRMSQuantKey(kFp8DynamicTokenSym, True): (
        torch.ops._C.rms_norm_dynamic_per_token_quant.default
    ),  # noqa: E501
}
...
...
@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
for
epsilon
in
[
1e-5
,
1e-6
]:
# Fuse rms_norm + static fp8 quant
#
RMSNormStaticQuantPattern(epsilon,
#
FP8_DTYPE).register(self.patterns)
RMSNormStaticQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
)
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# Fuse fused_add_rms_norm + static fp8 quant
#
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
#
self.patterns, self.record_match)
FusedAddRMSNormStaticQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
,
self
.
record_match
)
# Fuse rms_norm + dynamic per-token fp8 quant
#
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
#
self.patterns, self.record_match)
RMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
,
self
.
record_match
)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
#
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
#
self.patterns, self.record_match)
FusedAddRMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
self
.
patterns
,
self
.
record_match
)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
...
...
vllm/compilation/sequence_parallelism.py
View file @
572cd426
...
...
@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
for
epsilon
in
[
1e-5
,
1e-6
]:
# RMSNorm + Static FP8 quantization patterns
#
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
#
FirstAllReduceRMSNormStaticFP8Pattern(
#
epsilon, self.model_dtype, self.device,
#
fp8_quant_op).register(self.patterns)
#
MiddleAllReduceRMSNormStaticFP8Pattern(
#
epsilon, self.model_dtype, self.device,
#
fp8_quant_op).register(self.patterns)
#
LastAllReduceRMSNormStaticFP8Pattern(
#
epsilon, self.model_dtype, self.device,
#
fp8_quant_op).register(self.patterns)
fp8_quant_op
=
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
FirstAllReduceRMSNormStaticFP8Pattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
fp8_quant_op
).
register
(
self
.
patterns
)
MiddleAllReduceRMSNormStaticFP8Pattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
fp8_quant_op
).
register
(
self
.
patterns
)
LastAllReduceRMSNormStaticFP8Pattern
(
epsilon
,
self
.
model_dtype
,
self
.
device
,
fp8_quant_op
).
register
(
self
.
patterns
)
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment