Unverified Commit 9db4650e authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[MoE Refactor] Add more MoE layer tests (#39349)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
Signed-off-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 5e584ce9
...@@ -77,8 +77,8 @@ def _worker_parallel_launch( ...@@ -77,8 +77,8 @@ def _worker_parallel_launch(
*args: Any, *args: Any,
) -> None: ) -> None:
rank = node_rank * world_local_size + local_rank rank = node_rank * world_local_size + local_rank
torch.accelerator.set_device_index(local_rank)
device = torch.device("cuda", local_rank) device = torch.device("cuda", local_rank)
torch.accelerator.set_device_index(device)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl", backend="cpu:gloo,cuda:nccl",
init_method=init_method, init_method=init_method,
......
...@@ -65,8 +65,8 @@ fp8_dtype = torch.float8_e4m3fn # current_platform.fp8_dtype ...@@ -65,8 +65,8 @@ fp8_dtype = torch.float8_e4m3fn # current_platform.fp8_dtype
SHAPE_COMBOS = [ SHAPE_COMBOS = [
(1, 128, 256), (1, 128, 256),
(32, 1024, 512), (32, 512, 512),
(222, 2048, 2048), (222, 1024, 2048),
] ]
MAX_M = max([x[0] for x in SHAPE_COMBOS]) MAX_M = max([x[0] for x in SHAPE_COMBOS])
...@@ -95,7 +95,7 @@ if has_flashinfer_nvlink_one_sided(): ...@@ -95,7 +95,7 @@ if has_flashinfer_nvlink_one_sided():
BACKENDS += ["flashinfer_nvlink_one_sided"] BACKENDS += ["flashinfer_nvlink_one_sided"]
if has_deep_ep(): if has_deep_ep():
BACKENDS += ["deepep_low_latency", "deepep_high_throughput"] BACKENDS += ["deepep_high_throughput", "deepep_low_latency"]
if has_nixl_ep(): if has_nixl_ep():
BACKENDS += ["nixl_ep"] BACKENDS += ["nixl_ep"]
...@@ -103,6 +103,7 @@ if has_nixl_ep(): ...@@ -103,6 +103,7 @@ if has_nixl_ep():
QUANT_METHODS = [ QUANT_METHODS = [
None, None,
"fp8", "fp8",
"fp8_blocked",
"modelopt_fp8", "modelopt_fp8",
"modelopt_fp4", "modelopt_fp4",
] ]
...@@ -114,10 +115,21 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = { ...@@ -114,10 +115,21 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
"mori": {None, "fp8", "modelopt_fp8"}, "mori": {None, "fp8", "modelopt_fp8"},
"flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"}, "flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"},
"flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"}, "flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"},
"deepep_low_latency": {None, "modelopt_fp8", "modelopt_fp4"}, "deepep_low_latency": {None, "fp8_blocked", "modelopt_fp4"},
"deepep_high_throughput": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, "deepep_high_throughput": {None, "fp8_blocked", "modelopt_fp8", "modelopt_fp4"}, # noqa: E501
"nixl_ep": {None, "fp8", "modelopt_fp8"}, "nixl_ep": {None, "fp8", "modelopt_fp8"},
} }
# Map from backend -> (DP/EP support, DP support, TP support)
BACKEND_EP_DP_TP_SUPPORT: dict[str, tuple[bool, bool, bool]] = {
"allgather_reducescatter": (True, True, True),
"mori": (True, False, False),
"flashinfer_nvlink_two_sided": (False, True, False),
"flashinfer_nvlink_one_sided": (False, True, False),
"deepep_low_latency": (True, False, False),
"deepep_high_throughput": (True, False, False),
"nixl_ep": (True, False, False),
}
# fmt: on # fmt: on
# Which quantization methods support EPLB. # Which quantization methods support EPLB.
...@@ -424,27 +436,35 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ...@@ -424,27 +436,35 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
f"Skipping unsupported K {config.k} in {config.backend} w/o EP.", f"Skipping unsupported K {config.k} in {config.backend} w/o EP.",
) )
if config.enable_eplb and config.ep_size == 1: if config.backend is not None:
return False, "EPLB requires EP." supports_ep_dp, supports_dp, supports_tp = BACKEND_EP_DP_TP_SUPPORT[
config.backend
]
if config.enable_eplb and config.quantization not in EPLB_SUPPORTED_QUANTS: if config.tp_size > 1 and not supports_tp:
return False, f"EPLB not supported with {config.quantization} quantization." return False, f"{config.backend} does not support TP."
if config.enable_eplb and config.backend not in EPLB_SUPPORTED_BACKENDS: if config.dp_size > 1 and config.ep_size == 1 and not supports_dp:
return False, f"EPLB not supported with {config.backend}." return False, f"{config.backend} does not support DP."
if ( if config.dp_size > 1 and config.ep_size > 1 and not supports_ep_dp:
config.backend is not None return False, f"{config.backend} does not support EP/DP."
and config.backend.startswith("flashinfer_nvlink") else:
and config.ep_size > 1 if config.tp_size > 1 or config.ep_size > 1 or config.dp_size > 1:
): return False, "An all2all backend is required for parallelism."
return False, "flashinfer_nvlink EP not yet supported."
if config.enable_eplb and config.num_experts % config.dp_size != 0: if config.enable_eplb:
return False, "EPLB requires num_experts divisible by ep_size" if config.ep_size == 1:
return False, "EPLB requires EP."
if config.enable_eplb and config.ep_size == 1: if config.quantization not in EPLB_SUPPORTED_QUANTS:
return False, "EPLB only works with EP+DP" return False, f"EPLB not supported with {config.quantization} quantization."
if config.backend not in EPLB_SUPPORTED_BACKENDS:
return False, f"EPLB not supported with {config.backend}."
if config.num_experts % config.dp_size != 0:
return False, "EPLB requires num_experts divisible by ep_size"
# Disable fp4 tests until flashinfer is updated or the Dockerfile is # Disable fp4 tests until flashinfer is updated or the Dockerfile is
# modified to install cublasLt.h. See #39525. # modified to install cublasLt.h. See #39525.
...@@ -507,27 +527,48 @@ class QuantizedWeights: ...@@ -507,27 +527,48 @@ class QuantizedWeights:
def _quantize_fp8_halves( def _quantize_fp8_halves(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
block_shape: list[int] | None = None,
) -> QuantizedWeights: ) -> QuantizedWeights:
"""Quantize w13 gate/up halves separately to FP8, producing per-shard scales.""" """Quantize w13 gate/up halves separately to FP8, producing per-shard scales."""
half = w1.shape[1] // 2 half = w1.shape[1] // 2
w1q_a, w1s_a, _ = moe_quantize_weights( w1q_a, w1s_a, _ = moe_quantize_weights(
w1[:, :half, :], None, fp8_dtype, False, None w1[:, :half, :],
None,
fp8_dtype,
False,
block_shape,
) )
w1q_b, w1s_b, _ = moe_quantize_weights( w1q_b, w1s_b, _ = moe_quantize_weights(
w1[:, half:, :], None, fp8_dtype, False, None w1[:, half:, :],
None,
fp8_dtype,
False,
block_shape,
) )
assert w1s_a is not None and w1s_b is not None assert w1s_a is not None and w1s_b is not None
w2q, w2s, _ = moe_quantize_weights(w2, None, fp8_dtype, False, None) w2q, w2s, _ = moe_quantize_weights(w2, None, fp8_dtype, False, block_shape)
assert w2s is not None assert w2s is not None
if block_shape is not None:
# Blocked quantization: scales have shape (E, n_tiles, k_tiles)
# Concatenate gate and up scales along the n_tiles dimension (dim=1)
# to match the concatenation of gate and up weights
w13_weight_scale = torch.cat([w1s_a, w1s_b], dim=1)
# w2 scales keep their blocked shape (E, k_tiles, n_tiles)
w2_weight_scale = w2s
else:
# Non-blocked quantization: scales have shape (E, 1, 1)
# Each w1s_x is (E, 1, 1) -> reshape to (E, 1), cat to (E, 2)
w13_weight_scale = torch.cat([w1s_a.view(-1, 1), w1s_b.view(-1, 1)], dim=1)
# w2s is (E, 1, 1) -> reshape to (E,)
w2_weight_scale = w2s.view(-1)
return QuantizedWeights( return QuantizedWeights(
w13_weight=torch.cat([w1q_a, w1q_b], dim=1), w13_weight=torch.cat([w1q_a, w1q_b], dim=1),
w2_weight=w2q, w2_weight=w2q,
# Each w1s_x is (E, 1, 1) -> reshape to (E, 1), cat to (E, 2) w13_weight_scale=w13_weight_scale,
w13_weight_scale=torch.cat([w1s_a.view(-1, 1), w1s_b.view(-1, 1)], dim=1), w2_weight_scale=w2_weight_scale,
# w2s is (E, 1, 1) -> reshape to (E,)
w2_weight_scale=w2s.view(-1),
) )
...@@ -536,7 +577,7 @@ def quantization_to_quant_dtype( ...@@ -536,7 +577,7 @@ def quantization_to_quant_dtype(
) -> torch.dtype | str | None: ) -> torch.dtype | str | None:
if quantization is None: if quantization is None:
return None return None
elif quantization in ["fp8", "modelopt_fp8"]: elif quantization in ["fp8", "fp8_blocked", "modelopt_fp8"]:
return fp8_dtype return fp8_dtype
elif quantization in ["modelopt_fp4"]: elif quantization in ["modelopt_fp4"]:
return "nvfp4" return "nvfp4"
...@@ -558,6 +599,12 @@ def make_quant_config( ...@@ -558,6 +599,12 @@ def make_quant_config(
if quantization == "fp8": if quantization == "fp8":
return Fp8Config(True), _quantize_fp8_halves(w1, w2) return Fp8Config(True), _quantize_fp8_halves(w1, w2)
if quantization == "fp8_blocked":
block_shape = [128, 128]
return Fp8Config(True, weight_block_size=block_shape), _quantize_fp8_halves(
w1, w2, block_shape
)
if quantization == "modelopt_fp8": if quantization == "modelopt_fp8":
qw = _quantize_fp8_halves(w1, w2) qw = _quantize_fp8_halves(w1, w2)
# why? # why?
...@@ -896,11 +943,13 @@ def make_fused_moe_layer( ...@@ -896,11 +943,13 @@ def make_fused_moe_layer(
**kwargs, **kwargs,
) )
weight_scale_name = getattr(layer.quant_method, "weight_scale_name", "weight_scale")
for name, value in [ for name, value in [
("w13_weight", qw.w13_weight), ("w13_weight", qw.w13_weight),
("w2_weight", qw.w2_weight), ("w2_weight", qw.w2_weight),
("w13_weight_scale", qw.w13_weight_scale), (f"w13_{weight_scale_name}", qw.w13_weight_scale),
("w2_weight_scale", qw.w2_weight_scale), (f"w2_{weight_scale_name}", qw.w2_weight_scale),
("w13_weight_scale_2", qw.w13_weight_scale_2), ("w13_weight_scale_2", qw.w13_weight_scale_2),
("w2_weight_scale_2", qw.w2_weight_scale_2), ("w2_weight_scale_2", qw.w2_weight_scale_2),
("w13_input_scale", qw.w13_input_scale), ("w13_input_scale", qw.w13_input_scale),
...@@ -922,7 +971,7 @@ def make_fake_moe_layer( ...@@ -922,7 +971,7 @@ def make_fake_moe_layer(
top_k: int, top_k: int,
global_num_experts: int, global_num_experts: int,
in_dtype: torch.dtype, in_dtype: torch.dtype,
quant_dtype: torch.dtype | None, quantization: str | None,
renormalize: bool = False, renormalize: bool = False,
shared_experts_config: SharedExpertsConfig | None = None, shared_experts_config: SharedExpertsConfig | None = None,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
...@@ -948,6 +997,7 @@ def make_fake_moe_layer( ...@@ -948,6 +997,7 @@ def make_fake_moe_layer(
dp_size: int = 1, dp_size: int = 1,
ep_size: int = 1, ep_size: int = 1,
) -> Callable: ) -> Callable:
quant_dtype = None
activation = MoEActivation.from_str(activation) activation = MoEActivation.from_str(activation)
router = create_fused_moe_router( router = create_fused_moe_router(
...@@ -1139,7 +1189,6 @@ def _test_body_eplb( ...@@ -1139,7 +1189,6 @@ def _test_body_eplb(
routed_output_transform=routed_output_transform, routed_output_transform=routed_output_transform,
) )
# Necessary?
if eplb_moe_layer._expert_map is not None: if eplb_moe_layer._expert_map is not None:
eplb_moe_layer._expert_map = eplb_moe_layer._expert_map.to(device) eplb_moe_layer._expert_map = eplb_moe_layer._expert_map.to(device)
...@@ -1267,6 +1316,7 @@ def _run_one_config( ...@@ -1267,6 +1316,7 @@ def _run_one_config(
gate = test_data.gate gate = test_data.gate
routed_input_transform = test_data.routed_input_transform routed_input_transform = test_data.routed_input_transform
routed_output_transform = test_data.routed_output_transform routed_output_transform = test_data.routed_output_transform
activation = "silu"
baseline_layer = make_fake_moe_layer( baseline_layer = make_fake_moe_layer(
w1=w1, w1=w1,
...@@ -1274,7 +1324,7 @@ def _run_one_config( ...@@ -1274,7 +1324,7 @@ def _run_one_config(
top_k=top_k, top_k=top_k,
global_num_experts=num_experts, global_num_experts=num_experts,
in_dtype=in_dtype, in_dtype=in_dtype,
quant_dtype=None, # quantization_to_quant_dtype(quantization), quantization=quantization,
renormalize=False, renormalize=False,
shared_experts_config=shared_experts_config, shared_experts_config=shared_experts_config,
gate=gate, gate=gate,
...@@ -1284,6 +1334,7 @@ def _run_one_config( ...@@ -1284,6 +1334,7 @@ def _run_one_config(
tp_size=tp_size, tp_size=tp_size,
ep_size=ep_size, ep_size=ep_size,
dp_size=dp_size, dp_size=dp_size,
activation=activation,
) )
baseline_output = baseline_layer(hidden_states, router_logits) baseline_output = baseline_layer(hidden_states, router_logits)
...@@ -1328,9 +1379,9 @@ def _run_one_config( ...@@ -1328,9 +1379,9 @@ def _run_one_config(
gate=gate, gate=gate,
routed_input_transform=routed_input_transform, routed_input_transform=routed_input_transform,
routed_output_transform=routed_output_transform, routed_output_transform=routed_output_transform,
activation=activation,
) )
# Necessary?
if moe_layer._expert_map is not None: if moe_layer._expert_map is not None:
moe_layer._expert_map = moe_layer._expert_map.to(device) moe_layer._expert_map = moe_layer._expert_map.to(device)
...@@ -1377,13 +1428,17 @@ def _run_one_config( ...@@ -1377,13 +1428,17 @@ def _run_one_config(
atol, rtol = 7.6e-2, 7.6e-2 atol, rtol = 7.6e-2, 7.6e-2
else: else:
atol, rtol = 3.5e-2, 3.5e-2 atol, rtol = 3.5e-2, 3.5e-2
elif quantization in ("fp8", "modelopt_fp8"): elif quantization in ("fp8", "fp8_blocked", "modelopt_fp8"):
atol, rtol = 6e-2, 6e-2
elif quantization == "modelopt_fp4":
if k >= 2048: if k >= 2048:
atol, rtol = 7.6e-2, 7.6e-2 atol = rtol = 1e-1 + (k * 1e-4)
else: else:
atol, rtol = 6e-2, 6e-2 atol = rtol = 1e-1
elif quantization == "modelopt_fp4":
atol = rtol = 1e-1 + k * 5e-4 if backend == "allgather_reducescatter" and tp_size > 1:
atol += 2e-1
rtol += 2e-1
else: else:
atol, rtol = 6e-2, 6e-2 atol, rtol = 6e-2, 6e-2
......
...@@ -990,8 +990,7 @@ class FusedMoEParallelConfig: ...@@ -990,8 +990,7 @@ class FusedMoEParallelConfig:
@property @property
def use_batched_activation_format(self): def use_batched_activation_format(self):
# TODO(bnell): nixl also uses batched format return self.use_deepep_ll_kernels or self.use_nixl_ep_kernels
return self.use_deepep_ll_kernels
@property @property
def use_ag_rs_all2all_kernels(self): def use_ag_rs_all2all_kernels(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment