Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
148a5c12
Unverified
Commit
148a5c12
authored
Mar 28, 2026
by
IriKa
Committed by
GitHub
Mar 27, 2026
Browse files
[Bugfix]fix output Nan/Inf in marlin if dtype=float16 (#33972)
Signed-off-by:
IriKa Qiu
<
qiujie.jq@gmail.com
>
parent
b69bf2f0
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
83 additions
and
55 deletions
+83
-55
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+1
-1
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+22
-13
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+4
-4
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+0
-3
csrc/quantization/marlin/kernel.h
csrc/quantization/marlin/kernel.h
+1
-1
csrc/quantization/marlin/marlin.cu
csrc/quantization/marlin/marlin.cu
+5
-5
csrc/quantization/marlin/marlin_template.h
csrc/quantization/marlin/marlin_template.h
+15
-10
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
...el_executor/layers/quantization/utils/marlin_utils_fp4.py
+35
-18
No files found.
csrc/moe/marlin_moe_wna16/kernel.h
View file @
148a5c12
...
...
@@ -13,7 +13,7 @@
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const
uint16_
t *__restrict__ global_scale_ptr, \
const
floa
t *__restrict__ global_scale_ptr,
\
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
148a5c12
...
...
@@ -260,7 +260,7 @@ __global__ void Marlin(
// fp16 quantization scales. shape (k/groupsize, n)
const
int4
*
__restrict__
scales_ptr
,
// fp16 global scale (for nvfp4// only)
const
uint16_
t
*
__restrict__
global_scale_ptr
,
const
floa
t
*
__restrict__
global_scale_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const
int4
*
__restrict__
zp_ptr
,
...
...
@@ -308,7 +308,14 @@ __global__ void Marlin(
constexpr
int
moe_block_size
=
m_block_size_8
?
8
:
(
16
*
thread_m_blocks
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr
bool
use_fp16_accum
=
a_type_id
==
vllm
::
kFloat16
.
id
();
static
constexpr
auto
num_bits
=
vllm
::
ScalarType
::
from_id
(
b_type_id
).
size_bits
();
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
// num_bits == 4
constexpr
bool
use_fp16_accum
=
a_type_id
==
vllm
::
kFloat16
.
id
()
&&
(
!
(
b_type_id
==
vllm
::
kFE2M1f
.
id
()
&&
s_type_id
==
vllm
::
kFE4M3fn
.
id
())
&&
!
(
group_blocks
==
-
1
&&
num_bits
==
4
));
#else
constexpr
bool
use_fp16_accum
=
false
;
#endif
...
...
@@ -357,7 +364,7 @@ __global__ void Marlin(
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
b_type
==
vllm
::
kU8
);
c_scalar_t2
global_scale
;
float
global_scale
_f32
=
1.0
f
;
constexpr
bool
has_act_order
=
group_blocks
==
0
;
...
...
@@ -507,11 +514,12 @@ __global__ void Marlin(
if
(
mul_topk_weights
)
{
idx
=
idx
<
prob_m_top_k
?
idx
:
0
;
c_scalar_t2
topk_weight_val
=
Cdtype
::
num2num2
(
Cdtype
::
float2num
(
topk_weights_ptr
[
idx
]));
float
topk_weight_tmp
=
topk_weights_ptr
[
idx
];
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
topk_weight_
val
=
__hmul2
(
topk_weight_val
,
global_scale
)
;
topk_weight_
tmp
*=
global_scale
_f32
;
}
c_scalar_t2
topk_weight_val
=
Cdtype
::
num2num2
(
Cdtype
::
float2num
(
topk_weight_tmp
));
sh_block_topk_weights
[
threadIdx
.
x
]
=
topk_weight_val
;
}
}
...
...
@@ -532,8 +540,7 @@ __global__ void Marlin(
expert_id
=
expert_ids_ptr
[
block_id
];
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
uint16_t
val
=
global_scale_ptr
[
expert_id
];
global_scale
=
Cdtype
::
num2num2
(
*
reinterpret_cast
<
c_scalar_t
*>
(
&
val
));
global_scale_f32
=
global_scale_ptr
[
expert_id
];
}
B_expert_off
=
expert_id
*
prob_n
*
prob_k
/
(
pack_factor
*
4
);
...
...
@@ -1784,6 +1791,13 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
,
FragS
&
b_bias
)
{
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
if
(
!
mul_topk_weights
)
{
c0
*=
global_scale_f32
;
c1
*=
global_scale_f32
;
}
}
c_scalar_t2
res
=
Cdtype
::
nums2num2
(
Cdtype
::
float2num
(
c0
),
Cdtype
::
float2num
(
c1
));
...
...
@@ -1800,11 +1814,6 @@ __global__ void Marlin(
res
=
__hmul2
(
res
,
tmp_scale
);
}
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
if
(
!
mul_topk_weights
)
{
res
=
__hmul2
(
res
,
global_scale
);
}
}
if
(
has_bias
&&
last
)
{
c_scalar_t2
tmp_bias
=
b_bias
[
0
];
if
constexpr
(
m_block_size_8
)
{
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
148a5c12
...
...
@@ -382,7 +382,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
const
int4
*
bias_ptr
=
(
const
int4
*
)
b_bias
;
const
float
*
a_s_ptr
=
(
const
float
*
)
a_s
;
const
int4
*
b_s_ptr
=
(
const
int4
*
)
b_s
;
const
uint16_
t
*
g_s_ptr
=
(
const
uint16_
t
*
)
g_s
;
const
floa
t
*
g_s_ptr
=
(
const
floa
t
*
)
g_s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
const
int
*
perm_ptr
=
(
const
int
*
)
perm
;
...
...
@@ -759,7 +759,7 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
,
"global_scale can only be used for nvfp4 format."
);
}
else
{
global_scale
=
torch
::
empty
({
0
},
options
);
global_scale
=
torch
::
empty
({
0
},
options
_fp32
);
TORCH_CHECK
(
!
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
),
"the global_scale parameter must be passed for nvfp4 format."
);
}
...
...
@@ -842,8 +842,8 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK
(
a_scales
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"scalar type of a_scales must be float"
);
TORCH_CHECK
(
global_scale
.
scalar_type
()
==
c
.
s
calar
_t
ype
()
,
"scalar type of global_scale must be
the same with c
"
);
TORCH_CHECK
(
global_scale
.
scalar_type
()
==
at
::
S
calar
T
ype
::
Float
,
"scalar type of global_scale must be
float
"
);
if
(
a_type
.
size_bits
()
==
16
)
{
TORCH_CHECK
(
a
.
scalar_type
()
==
c
.
scalar_type
(),
...
...
csrc/quantization/activation_kernels.cu
View file @
148a5c12
...
...
@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
}
__device__
__forceinline__
float
clip
(
float
v
,
float
mmin
,
float
mmax
)
{
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
return
fminf
(
mmax
,
fmaxf
(
v
,
mmin
));
#else
#endif
}
__device__
__forceinline__
__nv_bfloat16
clip
(
__nv_bfloat16
v
,
...
...
csrc/quantization/marlin/kernel.h
View file @
148a5c12
...
...
@@ -13,7 +13,7 @@
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const
uint16_
t *__restrict__ global_scale_ptr, \
const
floa
t *__restrict__ global_scale_ptr,
\
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
...
...
csrc/quantization/marlin/marlin.cu
View file @
148a5c12
...
...
@@ -57,7 +57,7 @@ torch::Tensor marlin_gemm(
int64_t
size_k
,
bool
is_k_full
,
bool
use_atomic_add
,
bool
use_fp32_reduce
,
bool
is_zp_float
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_gemm(..) requires CUDA_ARCH >=
8.0
"
);
"marlin_gemm(..) requires CUDA_ARCH >=
7.5
"
);
return
torch
::
empty
({
1
,
1
});
}
...
...
@@ -356,7 +356,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
const
int4
*
bias_ptr
=
(
const
int4
*
)
b_bias
;
const
float
*
a_s_ptr
=
(
const
float
*
)
a_s
;
const
int4
*
b_s_ptr
=
(
const
int4
*
)
b_s
;
const
uint16_
t
*
g_s_ptr
=
(
const
uint16_
t
*
)
g_s
;
const
floa
t
*
g_s_ptr
=
(
const
floa
t
*
)
g_s
;
const
int4
*
zp_ptr
=
(
const
int4
*
)
zp
;
const
int
*
g_idx_ptr
=
(
const
int
*
)
g_idx
;
...
...
@@ -751,7 +751,7 @@ torch::Tensor marlin_gemm(
TORCH_CHECK
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
,
"global_scale can only be used for nvfp4 format."
);
}
else
{
global_scale
=
torch
::
empty
({
0
},
options
);
global_scale
=
torch
::
empty
({
0
},
options
_fp32
);
TORCH_CHECK
(
!
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
),
"the global_scale parameter must be passed for nvfp4 format."
);
}
...
...
@@ -832,8 +832,8 @@ torch::Tensor marlin_gemm(
TORCH_CHECK
(
a_scales
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"scalar type of a_scales must be float"
);
TORCH_CHECK
(
global_scale
.
scalar_type
()
==
c
.
s
calar
_t
ype
()
,
"scalar type of global_scale must be
the same with c
"
);
TORCH_CHECK
(
global_scale
.
scalar_type
()
==
at
::
S
calar
T
ype
::
Float
,
"scalar type of global_scale must be
float
"
);
if
(
a_type
.
size_bits
()
==
16
)
{
TORCH_CHECK
(
a
.
scalar_type
()
==
c
.
scalar_type
(),
...
...
csrc/quantization/marlin/marlin_template.h
View file @
148a5c12
...
...
@@ -251,8 +251,8 @@ __global__ void Marlin(
const
float
*
__restrict__
a_scales_ptr
,
// fp16 quantization scales. shape (k/groupsize, n)
const
int4
*
__restrict__
scales_ptr
,
// f
p16
global scale (for nvfp4// only)
const
uint16_
t
*
__restrict__
global_scale_ptr
,
// f
loat
global scale (for nvfp4// only)
const
floa
t
*
__restrict__
global_scale_ptr
,
// 4bit packed zero-points of shape
// (k/groupsize, n/pack_factor)
const
int4
*
__restrict__
zp_ptr
,
...
...
@@ -292,7 +292,13 @@ __global__ void Marlin(
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
constexpr
bool
use_fp16_accum
=
a_type_id
==
vllm
::
kFloat16
.
id
();
constexpr
auto
num_bits
=
vllm
::
ScalarType
::
from_id
(
b_type_id
).
size_bits
();
// Disable use_fp16_accum for NVFP4 and cases when group_size == -1 &&
// num_bits == 4
constexpr
bool
use_fp16_accum
=
a_type_id
==
vllm
::
kFloat16
.
id
()
&&
(
!
(
b_type_id
==
vllm
::
kFE2M1f
.
id
()
&&
s_type_id
==
vllm
::
kFE4M3fn
.
id
())
&&
!
(
group_blocks
==
-
1
&&
num_bits
==
4
));
#else
constexpr
bool
use_fp16_accum
=
false
;
#endif
...
...
@@ -342,11 +348,10 @@ __global__ void Marlin(
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
b_type
==
vllm
::
kU8
);
c_scalar_t2
global_scale
;
float
global_scale
_f32
=
1.0
f
;
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
uint16_t
val
=
global_scale_ptr
[
0
];
global_scale
=
Cdtype
::
num2num2
(
*
reinterpret_cast
<
c_scalar_t
*>
(
&
val
));
global_scale_f32
=
global_scale_ptr
[
0
];
}
constexpr
bool
has_act_order
=
group_blocks
==
0
;
...
...
@@ -1644,6 +1649,10 @@ __global__ void Marlin(
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
,
FragS
&
b_bias
)
{
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
c0
*=
global_scale_f32
;
c1
*=
global_scale_f32
;
}
c_scalar_t2
res
=
Cdtype
::
nums2num2
(
Cdtype
::
float2num
(
c0
),
Cdtype
::
float2num
(
c1
));
...
...
@@ -1659,10 +1668,6 @@ __global__ void Marlin(
}
res
=
__hmul2
(
res
,
tmp_scale
);
}
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
)
{
res
=
__hmul2
(
res
,
global_scale
);
}
if
(
has_bias
&&
last
)
{
c_scalar_t2
tmp_bias
=
b_bias
[
0
];
if
constexpr
(
m_block_size_8
)
{
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
View file @
148a5c12
...
...
@@ -27,10 +27,19 @@ def is_fp4_marlin_supported():
return
current_platform
.
has_device_capability
(
75
)
def
_nvfp4_compute_scale_factor
(
marlin_scales
:
torch
.
Tensor
)
->
float
:
def
_nvfp4_compute_scale_factor
(
marlin_scales
:
torch
.
Tensor
,
a_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
float
:
"""Compute the power-of-2 scale_factor needed so that all non-zero
values in marlin_scales * 2^7 are >= 2 after rescaling.
Returns a Python float (power of 2, >= 1.0)."""
# Since half has a smaller dynamic range compared to bfloat16,
# no rescaling is applied here if active dtype is half.
if
a_dtype
is
not
None
and
a_dtype
==
torch
.
half
:
return
1.0
ws_float
=
marlin_scales
.
float
()
*
(
2
**
7
)
nonzero_mask
=
ws_float
>
0
if
nonzero_mask
.
any
():
...
...
@@ -44,6 +53,7 @@ def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
def
nvfp4_marlin_process_scales
(
marlin_scales
:
torch
.
Tensor
,
scale_factor
:
float
|
None
=
None
,
a_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
float
]:
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
...
...
@@ -91,7 +101,7 @@ def nvfp4_marlin_process_scales(
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
# The caller must compensate by dividing global_scale by scale_factor.
if
scale_factor
is
None
:
scale_factor
=
_nvfp4_compute_scale_factor
(
marlin_scales
)
scale_factor
=
_nvfp4_compute_scale_factor
(
marlin_scales
,
a_dtype
)
if
scale_factor
>
1.0
:
marlin_scales
=
(
marlin_scales
.
float
()
*
scale_factor
).
to
(
torch
.
half
)
...
...
@@ -119,12 +129,14 @@ def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
return
marlin_scales
def
nvfp4_marlin_process_global_scale
(
global_scale
):
assert
global_scale
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
]
def
nvfp4_marlin_process_global_scale
(
global_scale
,
a_dtype
:
torch
.
dtype
|
None
=
None
):
if
a_dtype
is
None
:
a_dtype
=
global_scale
.
dtype
assert
a_dtype
in
[
torch
.
half
,
torch
.
bfloat16
]
fp4_exponent
=
2
if
global_scale
.
dtype
==
torch
.
half
:
if
a_
dtype
==
torch
.
half
:
target_exponent
=
5
elif
global_scale
.
dtype
==
torch
.
bfloat16
:
elif
a_
dtype
==
torch
.
bfloat16
:
target_exponent
=
8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
# exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
...
...
@@ -244,11 +256,15 @@ def prepare_fp4_layer_for_marlin(
)
if
is_nvfp4
:
weight_scale
,
scale_factor
=
nvfp4_marlin_process_scales
(
weight_scale
)
weight_scale
,
scale_factor
=
nvfp4_marlin_process_scales
(
weight_scale
,
a_dtype
=
param_dtype
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
weight_scale
,
requires_grad
=
False
)
weight_global_scale
=
layer
.
weight_global_scale
.
to
(
param_dtype
)
weight_global_scale
=
nvfp4_marlin_process_global_scale
(
weight_global_scale
)
weight_global_scale
=
layer
.
weight_global_scale
.
to
(
torch
.
float32
)
weight_global_scale
=
nvfp4_marlin_process_global_scale
(
weight_global_scale
,
param_dtype
)
weight_global_scale
=
weight_global_scale
/
scale_factor
layer
.
weight_global_scale
=
torch
.
nn
.
Parameter
(
weight_global_scale
,
requires_grad
=
False
...
...
@@ -339,7 +355,6 @@ def prepare_nvfp4_moe_layer_for_marlin(
scales
:
torch
.
Tensor
,
g_scales
:
torch
.
Tensor
,
name
:
str
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
scales
=
scales
.
to
(
param_dtype
)
g_scales
=
g_scales
.
to
(
param_dtype
)
tensor_list
=
[]
num_shards
=
2
if
is_act_and_mul
else
1
...
...
@@ -350,7 +365,7 @@ def prepare_nvfp4_moe_layer_for_marlin(
# All experts share one global_scale, so compute the max
# scale_factor across all experts first, then apply uniformly.
combined_scale_factor
=
_nvfp4_compute_scale_factor
(
scales
)
combined_scale_factor
=
_nvfp4_compute_scale_factor
(
scales
,
param_dtype
)
for
i
in
range
(
E
):
scale
=
scales
[
i
].
T
...
...
@@ -362,12 +377,12 @@ def prepare_nvfp4_moe_layer_for_marlin(
is_a_8bit
=
is_a_8bit
,
)
marlin_scales
,
_
=
nvfp4_marlin_process_scales
(
marlin_scales
,
scale_factor
=
combined_scale_factor
marlin_scales
,
scale_factor
=
combined_scale_factor
,
a_dtype
=
param_dtype
)
tensor_list
.
append
(
marlin_scales
)
scales
=
torch
.
cat
([
x
.
unsqueeze
(
0
)
for
x
in
tensor_list
],
0
)
g_scales
=
nvfp4_marlin_process_global_scale
(
g_scales
)
g_scales
=
nvfp4_marlin_process_global_scale
(
g_scales
,
param_dtype
)
g_scales
=
g_scales
/
combined_scale_factor
return
scales
,
g_scales
...
...
@@ -438,7 +453,7 @@ def prepare_moe_fp4_layer_for_marlin(
scales
=
scales
.
view
(
torch
.
float8_e8m0fnu
)
scales
=
scales
.
to
(
param_dtype
)
if
is_nvfp4
:
global_scale
=
getattr
(
layer
,
name
+
"_weight_scale_2"
)
.
to
(
param_dtype
)
global_scale
=
getattr
(
layer
,
name
+
"_weight_scale_2"
)
tensor_list
=
[]
if
"w13"
in
name
:
...
...
@@ -449,7 +464,7 @@ def prepare_moe_fp4_layer_for_marlin(
# For NVFP4: compute unified scale_factor across all experts
combined_scale_factor
=
None
if
is_nvfp4
:
combined_scale_factor
=
_nvfp4_compute_scale_factor
(
scales
)
combined_scale_factor
=
_nvfp4_compute_scale_factor
(
scales
,
param_dtype
)
for
i
in
range
(
e
):
scale
=
scales
[
i
].
T
...
...
@@ -463,7 +478,9 @@ def prepare_moe_fp4_layer_for_marlin(
)
if
is_nvfp4
:
marlin_scales
,
_
=
nvfp4_marlin_process_scales
(
marlin_scales
,
scale_factor
=
combined_scale_factor
marlin_scales
,
scale_factor
=
combined_scale_factor
,
a_dtype
=
param_dtype
,
)
else
:
marlin_scales
=
mxfp4_marlin_process_scales
(
...
...
@@ -477,7 +494,7 @@ def prepare_moe_fp4_layer_for_marlin(
if
is_nvfp4
:
assert
combined_scale_factor
is
not
None
global_scale
=
nvfp4_marlin_process_global_scale
(
global_scale
)
global_scale
=
nvfp4_marlin_process_global_scale
(
global_scale
,
param_dtype
)
global_scale
=
global_scale
/
combined_scale_factor
global_scale
=
torch
.
nn
.
Parameter
(
global_scale
,
requires_grad
=
False
)
setattr
(
layer
,
name
+
"_weight_scale_2"
,
global_scale
)
...
...
@@ -665,7 +682,7 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
)
marlin_scales
,
scale_factor
=
nvfp4_marlin_process_scales
(
marlin_scales
)
global_scale
=
nvfp4_marlin_process_global_scale
(
global_scale
)
global_scale
=
nvfp4_marlin_process_global_scale
(
global_scale
)
.
to
(
torch
.
float32
)
global_scale
=
global_scale
/
scale_factor
return
weight_ref
.
T
,
marlin_qweight
,
marlin_scales
,
global_scale
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment