@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
...
...
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if intermediate_size % block_n != 0:
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{intermediate_size} is not divisible by "
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if intermediate_size % block_k != 0:
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
@@ -527,7 +534,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
...
...
@@ -543,18 +550,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if intermediate_size % block_n != 0:
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{intermediate_size} is not divisible by "
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if intermediate_size % block_k != 0:
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
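
# A sketch of what the per-partition sizes imply for the block-wise scale
# tensors, assuming a 128x128 block shape and an even TP split; the numbers
# and shapes here are illustrative, not taken from the library's
# allocation code.
intermediate_size_per_partition = 14336 // 4  # 3584 per TP rank
hidden_size = 7168
block_n, block_k = 128, 128

# Merged gate/up (w13): one scale per (block_n, block_k) tile of the
# (2 * intermediate_size_per_partition, hidden_size) shard.
w13_scale_shape = (
    2 * intermediate_size_per_partition // block_n,
    hidden_size // block_k,
)
# Down (w2): the shard is (hidden_size, intermediate_size_per_partition).
w2_scale_shape = (
    hidden_size // block_n,
    intermediate_size_per_partition // block_k,
)
print(w13_scale_shape, w2_scale_shape)  # (56, 56) (56, 28)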
...
...
@@ -564,7 +571,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         w13_weight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 hidden_size // 8,
                 dtype=params_dtype,
             ),
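
# The 2x factor above exists because the gate and up projections are stored
# merged in a single w13 tensor; below is a sketch of that layout with
# assumed numbers. The hidden_size // 8 last dim reads as packed sub-byte
# storage (e.g. eight 4-bit values per 32-bit element); that is an
# assumption, since this excerpt doesn't show the dtype handling.
import torch

num_experts = 8
intermediate_size_per_partition = 3584
hidden_size = 7168

w13_weight = torch.empty(
    num_experts,
    2 * intermediate_size_per_partition,
    hidden_size // 8,  # packed: several values per stored element
    dtype=torch.int32,  # illustrative; the diff uses params_dtype
)
# First half of dim 1 is the gate projection, second half is up.
gate_w = w13_weight[:, :intermediate_size_per_partition]
up_w = w13_weight[:, intermediate_size_per_partition:]
assert gate_w.shape == up_w.shape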
...
...
@@ -572,20 +579,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):