# NOTE(HandH1998): To ensure proper alignment of the block-wise
# quantization scales, the output_size of the weights for both the
# gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
    raise ValueError(
        f"The output_size of gate's and up's weight = "
        f"{intermediate_size} is not divisible by "
        f"weight quantization block_n = {block_n}."
    )
if tp_size > 1:
    # Required by row parallel
    if intermediate_size % block_k != 0:
        raise ValueError(
            f"The input_size of down's weight = "
            f"{intermediate_size} is not divisible by "
            f"weight quantization block_k = {block_k}."
        )
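To see concretely what these checks protect, consider how the block-wise scale grid is laid out: there is one scale per block_n × block_k tile, computed by ceiling division, and tensor parallelism slices the weight along exactly those dimensions. The sketch below is illustrative only; the sizes and the ceil_div helper are assumptions, not upstream code.

```python
# Minimal sketch with assumed example sizes; ceil_div is a hypothetical
# helper, not the upstream code.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

intermediate_size, hidden_size = 14336, 4096  # assumed sizes
block_n, block_k = 128, 128                   # typical FP8 block shape

# Block-wise quantization stores one scale per (block_n x block_k) tile:
n_tiles = ceil_div(intermediate_size, block_n)  # tiles along output_size
k_tiles = ceil_div(hidden_size, block_k)        # tiles along input_size
print(n_tiles, k_tiles)  # 112 32

# Column parallel (gate/up) shards output_size; row parallel (down)
# shards input_size. If intermediate_size were not a multiple of the
# block size, a shard boundary could fall inside a scale tile and the
# per-rank scale tensors could no longer be sliced cleanly.
assert intermediate_size % block_n == 0  # gate/up output dim
assert intermediate_size % block_k == 0  # down input dim when tp_size > 1
```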
# WEIGHTS
w13_weight = torch.nn.Parameter(
...

@@ -374,21 +476,45 @@ class Fp8MoEMethod:
...
set_weight_attrs(w2_weight, extra_weight_attrs)
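The hunk above elides the actual allocations; as rough orientation, a fused-MoE layout typically looks like the hedged sketch below, where w13 fuses the gate and up projections and w2 is the down projection. All shapes and the dtype are assumptions, not lines recovered from the diff.

```python
import torch

# Hedged sketch only; shapes and dtype are assumed, not copied from
# the diff. Each expert's weights are stacked along dim 0.
num_experts, intermediate_size, hidden_size = 8, 14336, 4096
params_dtype = torch.float8_e4m3fn  # FP8-serialized checkpoint case

# w13 fuses gate and up: output_size = 2 * intermediate_size.
w13_weight = torch.nn.Parameter(
    torch.empty(num_experts, 2 * intermediate_size, hidden_size,
                dtype=params_dtype),
    requires_grad=False,
)
# w2 is the down projection: input_size = intermediate_size.
w2_weight = torch.nn.Parameter(
    torch.empty(num_experts, hidden_size, intermediate_size,
                dtype=params_dtype),
    requires_grad=False,
)
print(w13_weight.shape)  # torch.Size([8, 28672, 4096])
print(w2_weight.shape)   # torch.Size([8, 4096, 14336])
```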
# WEIGHT_SCALES
if self.block_quant:
    ...
else:
    # Allocate 2 scales for w1 and w3 respectively.
    # They will be combined to a single scale after weight loading.
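Finally, a hedged sketch of the two scale-allocation paths this branch separates. The block-quant shape follows the ceil-division tiling from the earlier sketch; the per-tensor shape follows the source comment (one scale each for w1 and w3, merged after loading). Variable names and exact shapes are assumptions, not the upstream code.

```python
import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, intermediate_size, hidden_size = 8, 14336, 4096
block_n, block_k = 128, 128
block_quant = True  # assumed flag mirroring self.block_quant

if block_quant:
    # One float32 scale per (block_n x block_k) tile, per expert;
    # the leading 2x covers the fused gate and up projections.
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts,
                   2 * ceil_div(intermediate_size, block_n),
                   ceil_div(hidden_size, block_k),
                   dtype=torch.float32),
        requires_grad=False,
    )
else:
    # Per-tensor path: one scale each for w1 and w3, combined into a
    # single scale after weight loading (per the source comment).
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, 2, dtype=torch.float32),
        requires_grad=False,
    )
```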