"next_docs/en/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "b2887ca0aa2d908392746d6e758db615b4ade27b"
Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
 def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
     assert len(x.shape) == 2
-    assert (
-        x.shape == residual.shape and x.dtype == residual.dtype
-    ), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
+    assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
     if autotune:
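The hunk above collapses the multi-line assert (and its f-string message) into a single bare assert. For orientation, here is a minimal unfused reference of the dual-residual RMSNorm pattern that the Triton kernel fuses; the composition order below is an assumption for illustration, not read from the diff:

import torch

def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain RMSNorm: scale x by its reciprocal RMS, then by the learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

def dual_residual_rmsnorm_reference(x, residual, weight1, weight2, eps):
    # Assumed (unfused) semantics: normalize x with weight1, fold it into the
    # residual stream, then normalize the updated residual with weight2.
    # Returns (output, mid) like the fused kernel, where mid is the new residual.
    mid = residual + rmsnorm(x, weight1, eps)
    output = rmsnorm(mid, weight2, eps)
    return output, mid

The fused kernel presumably does both normalizations and the residual add in one pass over hidden_dim, which is why the autotuned variant specializes on bs and hidden_dim.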
@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
         )
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if (
-            should_use_flashinfer_trtllm_moe()
-            and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
-        ):
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
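The diff's own comment states that flashinfer assumes the w31 ordering for the combined w13 (gate/up) weight and its scales; the simplified condition now applies the shard-id swap on every flashinfer TRT-LLM MoE path rather than only for ModelOptNvFp4FusedMoEMethod. A small sketch of what the swap amounts to (shapes and names below are illustrative assumptions):

import torch

# Hypothetical shard shapes, purely for illustration.
w1 = torch.randn(4, 8)   # gate projection shard ("w1")
w3 = torch.randn(4, 8)   # up projection shard ("w3")

w13 = torch.cat([w1, w3], dim=0)  # "w13" layout: gate stacked first
w31 = torch.cat([w3, w1], dim=0)  # "w31" layout: up stacked first (what flashinfer expects)

# Remapping shard ids at load time writes each shard straight into the slot
# the w31 layout expects, without physically reordering the stacked tensor.
shard_remap = {"w1": "w3", "w3": "w1", "w2": "w2"}
assert shard_remap["w1"] == "w3" and shard_remap["w3"] == "w1"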