Unverified commit 9db4650e, authored by bnellnm and committed by GitHub
Browse files

[MoE Refactor] Add more MoE layer tests (#39349)


Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 5e584ce9
......@@ -77,8 +77,8 @@ def _worker_parallel_launch(
*args: Any,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.accelerator.set_device_index(local_rank)
device = torch.device("cuda", local_rank)
torch.accelerator.set_device_index(device)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
......
......@@ -65,8 +65,8 @@ fp8_dtype = torch.float8_e4m3fn # current_platform.fp8_dtype
SHAPE_COMBOS = [
(1, 128, 256),
(32, 1024, 512),
(222, 2048, 2048),
(32, 512, 512),
(222, 1024, 2048),
]
MAX_M = max([x[0] for x in SHAPE_COMBOS])
......@@ -95,7 +95,7 @@ if has_flashinfer_nvlink_one_sided():
BACKENDS += ["flashinfer_nvlink_one_sided"]
if has_deep_ep():
BACKENDS += ["deepep_low_latency", "deepep_high_throughput"]
BACKENDS += ["deepep_high_throughput", "deepep_low_latency"]
if has_nixl_ep():
BACKENDS += ["nixl_ep"]
......@@ -103,6 +103,7 @@ if has_nixl_ep():
QUANT_METHODS = [
None,
"fp8",
"fp8_blocked",
"modelopt_fp8",
"modelopt_fp4",
]
......@@ -114,10 +115,21 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
"mori": {None, "fp8", "modelopt_fp8"},
"flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"},
"flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"},
"deepep_low_latency": {None, "modelopt_fp8", "modelopt_fp4"},
"deepep_high_throughput": {None, "fp8", "modelopt_fp8", "modelopt_fp4"},
"deepep_low_latency": {None, "fp8_blocked", "modelopt_fp4"},
"deepep_high_throughput": {None, "fp8_blocked", "modelopt_fp8", "modelopt_fp4"}, # noqa: E501
"nixl_ep": {None, "fp8", "modelopt_fp8"},
}
# Map from backend -> (DP/EP support, DP support, TP support)
BACKEND_EP_DP_TP_SUPPORT: dict[str, tuple[bool, bool, bool]] = {
"allgather_reducescatter": (True, True, True),
"mori": (True, False, False),
"flashinfer_nvlink_two_sided": (False, True, False),
"flashinfer_nvlink_one_sided": (False, True, False),
"deepep_low_latency": (True, False, False),
"deepep_high_throughput": (True, False, False),
"nixl_ep": (True, False, False),
}
# fmt: on
# Which quantization methods support EPLB.
......@@ -424,27 +436,35 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]:
f"Skipping unsupported K {config.k} in {config.backend} w/o EP.",
)
if config.enable_eplb and config.ep_size == 1:
return False, "EPLB requires EP."
if config.backend is not None:
supports_ep_dp, supports_dp, supports_tp = BACKEND_EP_DP_TP_SUPPORT[
config.backend
]
if config.enable_eplb and config.quantization not in EPLB_SUPPORTED_QUANTS:
return False, f"EPLB not supported with {config.quantization} quantization."
if config.tp_size > 1 and not supports_tp:
return False, f"{config.backend} does not support TP."
if config.enable_eplb and config.backend not in EPLB_SUPPORTED_BACKENDS:
return False, f"EPLB not supported with {config.backend}."
if config.dp_size > 1 and config.ep_size == 1 and not supports_dp:
return False, f"{config.backend} does not support DP."
if (
config.backend is not None
and config.backend.startswith("flashinfer_nvlink")
and config.ep_size > 1
):
return False, "flashinfer_nvlink EP not yet supported."
if config.dp_size > 1 and config.ep_size > 1 and not supports_ep_dp:
return False, f"{config.backend} does not support EP/DP."
else:
if config.tp_size > 1 or config.ep_size > 1 or config.dp_size > 1:
return False, "An all2all backend is required for parallelism."
if config.enable_eplb and config.num_experts % config.dp_size != 0:
return False, "EPLB requires num_experts divisible by ep_size"
if config.enable_eplb:
if config.ep_size == 1:
return False, "EPLB requires EP."
if config.enable_eplb and config.ep_size == 1:
return False, "EPLB only works with EP+DP"
if config.quantization not in EPLB_SUPPORTED_QUANTS:
return False, f"EPLB not supported with {config.quantization} quantization."
if config.backend not in EPLB_SUPPORTED_BACKENDS:
return False, f"EPLB not supported with {config.backend}."
if config.num_experts % config.dp_size != 0:
return False, "EPLB requires num_experts divisible by ep_size"
# Disable fp4 tests until flashinfer is updated or the Dockerfile is
# modified to install cublasLt.h. See #39525.
......@@ -507,27 +527,48 @@ class QuantizedWeights:
def _quantize_fp8_halves(
w1: torch.Tensor,
w2: torch.Tensor,
block_shape: list[int] | None = None,
) -> QuantizedWeights:
"""Quantize w13 gate/up halves separately to FP8, producing per-shard scales."""
half = w1.shape[1] // 2
w1q_a, w1s_a, _ = moe_quantize_weights(
w1[:, :half, :], None, fp8_dtype, False, None
w1[:, :half, :],
None,
fp8_dtype,
False,
block_shape,
)
w1q_b, w1s_b, _ = moe_quantize_weights(
w1[:, half:, :], None, fp8_dtype, False, None
w1[:, half:, :],
None,
fp8_dtype,
False,
block_shape,
)
assert w1s_a is not None and w1s_b is not None
w2q, w2s, _ = moe_quantize_weights(w2, None, fp8_dtype, False, None)
w2q, w2s, _ = moe_quantize_weights(w2, None, fp8_dtype, False, block_shape)
assert w2s is not None
if block_shape is not None:
# Blocked quantization: scales have shape (E, n_tiles, k_tiles)
# Concatenate gate and up scales along the n_tiles dimension (dim=1)
# to match the concatenation of gate and up weights
w13_weight_scale = torch.cat([w1s_a, w1s_b], dim=1)
# w2 scales keep their blocked shape (E, k_tiles, n_tiles)
w2_weight_scale = w2s
else:
# Non-blocked quantization: scales have shape (E, 1, 1)
# Each w1s_x is (E, 1, 1) -> reshape to (E, 1), cat to (E, 2)
w13_weight_scale = torch.cat([w1s_a.view(-1, 1), w1s_b.view(-1, 1)], dim=1)
# w2s is (E, 1, 1) -> reshape to (E,)
w2_weight_scale = w2s.view(-1)
return QuantizedWeights(
w13_weight=torch.cat([w1q_a, w1q_b], dim=1),
w2_weight=w2q,
# Each w1s_x is (E, 1, 1) -> reshape to (E, 1), cat to (E, 2)
w13_weight_scale=torch.cat([w1s_a.view(-1, 1), w1s_b.view(-1, 1)], dim=1),
# w2s is (E, 1, 1) -> reshape to (E,)
w2_weight_scale=w2s.view(-1),
w13_weight_scale=w13_weight_scale,
w2_weight_scale=w2_weight_scale,
)
......@@ -536,7 +577,7 @@ def quantization_to_quant_dtype(
) -> torch.dtype | str | None:
if quantization is None:
return None
elif quantization in ["fp8", "modelopt_fp8"]:
elif quantization in ["fp8", "fp8_blocked", "modelopt_fp8"]:
return fp8_dtype
elif quantization in ["modelopt_fp4"]:
return "nvfp4"
......@@ -558,6 +599,12 @@ def make_quant_config(
if quantization == "fp8":
return Fp8Config(True), _quantize_fp8_halves(w1, w2)
if quantization == "fp8_blocked":
block_shape = [128, 128]
return Fp8Config(True, weight_block_size=block_shape), _quantize_fp8_halves(
w1, w2, block_shape
)
if quantization == "modelopt_fp8":
qw = _quantize_fp8_halves(w1, w2)
# why?
......@@ -896,11 +943,13 @@ def make_fused_moe_layer(
**kwargs,
)
weight_scale_name = getattr(layer.quant_method, "weight_scale_name", "weight_scale")
for name, value in [
("w13_weight", qw.w13_weight),
("w2_weight", qw.w2_weight),
("w13_weight_scale", qw.w13_weight_scale),
("w2_weight_scale", qw.w2_weight_scale),
(f"w13_{weight_scale_name}", qw.w13_weight_scale),
(f"w2_{weight_scale_name}", qw.w2_weight_scale),
("w13_weight_scale_2", qw.w13_weight_scale_2),
("w2_weight_scale_2", qw.w2_weight_scale_2),
("w13_input_scale", qw.w13_input_scale),
......@@ -922,7 +971,7 @@ def make_fake_moe_layer(
top_k: int,
global_num_experts: int,
in_dtype: torch.dtype,
quant_dtype: torch.dtype | None,
quantization: str | None,
renormalize: bool = False,
shared_experts_config: SharedExpertsConfig | None = None,
use_grouped_topk: bool = False,
......@@ -948,6 +997,7 @@ def make_fake_moe_layer(
dp_size: int = 1,
ep_size: int = 1,
) -> Callable:
quant_dtype = None
activation = MoEActivation.from_str(activation)
router = create_fused_moe_router(
......@@ -1139,7 +1189,6 @@ def _test_body_eplb(
routed_output_transform=routed_output_transform,
)
# Necessary?
if eplb_moe_layer._expert_map is not None:
eplb_moe_layer._expert_map = eplb_moe_layer._expert_map.to(device)
......@@ -1267,6 +1316,7 @@ def _run_one_config(
gate = test_data.gate
routed_input_transform = test_data.routed_input_transform
routed_output_transform = test_data.routed_output_transform
activation = "silu"
baseline_layer = make_fake_moe_layer(
w1=w1,
......@@ -1274,7 +1324,7 @@ def _run_one_config(
top_k=top_k,
global_num_experts=num_experts,
in_dtype=in_dtype,
quant_dtype=None, # quantization_to_quant_dtype(quantization),
quantization=quantization,
renormalize=False,
shared_experts_config=shared_experts_config,
gate=gate,
......@@ -1284,6 +1334,7 @@ def _run_one_config(
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
activation=activation,
)
baseline_output = baseline_layer(hidden_states, router_logits)
......@@ -1328,9 +1379,9 @@ def _run_one_config(
gate=gate,
routed_input_transform=routed_input_transform,
routed_output_transform=routed_output_transform,
activation=activation,
)
# Necessary?
if moe_layer._expert_map is not None:
moe_layer._expert_map = moe_layer._expert_map.to(device)
......@@ -1377,13 +1428,17 @@ def _run_one_config(
atol, rtol = 7.6e-2, 7.6e-2
else:
atol, rtol = 3.5e-2, 3.5e-2
elif quantization in ("fp8", "modelopt_fp8"):
elif quantization in ("fp8", "fp8_blocked", "modelopt_fp8"):
atol, rtol = 6e-2, 6e-2
elif quantization == "modelopt_fp4":
if k >= 2048:
atol, rtol = 7.6e-2, 7.6e-2
atol = rtol = 1e-1 + (k * 1e-4)
else:
atol, rtol = 6e-2, 6e-2
elif quantization == "modelopt_fp4":
atol = rtol = 1e-1 + k * 5e-4
atol = rtol = 1e-1
if backend == "allgather_reducescatter" and tp_size > 1:
atol += 2e-1
rtol += 2e-1
else:
atol, rtol = 6e-2, 6e-2
......
......@@ -990,8 +990,7 @@ class FusedMoEParallelConfig:
@property
def use_batched_activation_format(self):
# TODO(bnell): nixl also uses batched format
return self.use_deepep_ll_kernels
return self.use_deepep_ll_kernels or self.use_nixl_ep_kernels
@property
def use_ag_rs_all2all_kernels(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment