Unverified Commit e8a69e4d authored by Lianmin Zheng, committed by GitHub

Clean up fp8 support (#4230)

parent fbd56002
@@ -55,6 +55,7 @@ jobs:
       timeout-minutes: 20
       run: |
         docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_accuracy_large.py
+        docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py

   mla-test-1-gpu-amd:
     if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false
...
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
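For context on the hunk above: normalize_e4m3fn_to_e4m3fnuz rescales OCP float8_e4m3fn checkpoints into the float8_e4m3fnuz variant used on MI300-class hardware. Below is a minimal sketch of what such a helper typically does, modeled on the analogous vLLM utility rather than copied from sglang, so details may differ.

# Sketch only: convert e4m3fn-quantized weights/scales to e4m3fnuz semantics.
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Bit pattern 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern decodes to half the value in e4m3fnuz (exponent
    # bias 8 vs 7), so double the scales to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale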
@@ -624,56 +624,9 @@ class Fp8MoEMethod:
     def process_weights_after_loading(self, layer: Module) -> None:
         if get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
-            # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
-            # Weight Permutation
-            layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
-            # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
-            # We won't do requant each expert's fp8 weight (not direct available),
-            # instead we adjust half of INT4 w13_weight_scale1 numbers
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
-                start = 0
-                max_w13_scale_fp8 = max_w13_scales[expert_id]
-                for shard_id in range(2):
-                    if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
-                        int4_rescale = (
-                            layer.w13_weight_scale[expert_id][shard_id]
-                            / max_w13_scale_fp8
-                        )
-                        layer.w13_weight_scale1[expert_id][
-                            start : start + shard_size
-                        ] *= int4_rescale
-                    start += shard_size
-            layer.w13_weight_scale = torch.nn.Parameter(
-                max_w13_scales, requires_grad=False
-            )
-            # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
-            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.num_experts):
-                layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
-                layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+            self.process_weights_hip_int4(layer)
             return

-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-            padding_size,  # Avoid circular import
-        )
-
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
@@ -710,6 +663,7 @@ class Fp8MoEMethod:
                     layer.w2_weight.contiguous(), (16, 16)
                 )
             return
+
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
@@ -736,32 +690,7 @@ class Fp8MoEMethod:
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+                self.process_weights_hip_scale_padding(layer)
             return

         # If checkpoint is fp8, we need to handle that the
@@ -843,6 +772,57 @@ class Fp8MoEMethod:
             )
         if is_hip_:
+            self.process_weights_hip_scale_padding(layer)
+        return
+
+    def process_weights_hip_int4(self, layer: Module):
+        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+        # Weight Permutation
+        layer.w13_weight = torch.nn.Parameter(
+            permute_weight(layer.w13_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        layer.w2_weight = torch.nn.Parameter(
+            permute_weight(layer.w2_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+        # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+        # We won't do requant each expert's fp8 weight (not direct available),
+        # instead we adjust half of INT4 w13_weight_scale1 numbers
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale1[expert_id][
+                        start : start + shard_size
+                    ] *= int4_rescale
+                start += shard_size
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+        # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+        # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+        for expert_id in range(layer.num_experts):
+            layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def process_weights_hip_scale_padding(self, layer: Module):
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
if get_bool_env_var("CK_MOE"): if get_bool_env_var("CK_MOE"):
layer.w13_weight = torch.nn.Parameter( layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data), permute_weight(layer.w13_weight.data),
...@@ -869,7 +849,6 @@ class Fp8MoEMethod: ...@@ -869,7 +849,6 @@ class Fp8MoEMethod:
requires_grad=False, requires_grad=False,
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
return
def apply( def apply(
self, self,
......
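The INT4-FP8 scale folding that moved into process_weights_hip_int4 is easiest to sanity-check in isolation. The following is a self-contained sketch of the same rescaling idea with made-up shapes (a handful of experts, two shards of shard_size columns in the fused w13 weight, e.g. the gate/up halves); the tensor names mirror the diff, but the sglang layer layout and the asm_moe kernel interface are not reproduced here.

# Standalone sketch of per-expert scale folding for INT4-FP8 MoE weights.
# Shapes are illustrative; this is not the sglang layer layout.
import torch

num_experts, shard_size = 4, 8
# One fp8 scale per (expert, shard).
w13_weight_scale = torch.rand(num_experts, 2) + 0.5
# One INT4 scale per output column of the fused w13 weight.
w13_weight_scale1 = torch.rand(num_experts, 2 * shard_size) + 0.5

# Reference: the effective per-column scaling the kernel must end up applying.
reference = w13_weight_scale1 * w13_weight_scale.repeat_interleave(shard_size, dim=1)

# Fold the two per-shard fp8 scales into one per-expert scale (the max) by
# pushing the ratio into the per-column INT4 scales of the smaller shard.
max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(num_experts):
    start = 0
    for shard_id in range(2):
        if w13_weight_scale[expert_id, shard_id] != max_w13_scales[expert_id]:
            rescale = w13_weight_scale[expert_id, shard_id] / max_w13_scales[expert_id]
            w13_weight_scale1[expert_id, start : start + shard_size] *= rescale
        start += shard_size

# (weight_scale1 * weight_scale) per column is unchanged by the folding.
folded = w13_weight_scale1 * max_w13_scales.unsqueeze(-1)
assert torch.allclose(folded, reference)

The final assertion is the invariant the refactor relies on: after folding, a single per-expert weight_scale combined with the adjusted per-column weight_scale1 reproduces the original per-column scaling.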
-import os
 from typing import List, Optional, Tuple

 import torch
-from packaging.version import Version

 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -13,18 +11,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
+    is_cuda,
     is_hip,
 )

-use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
-    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
-)
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

 is_hip_ = is_hip()
 if is_hip_ and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
@@ -73,7 +70,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 def cutlass_block_fp8_supported() -> bool:
-    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
...
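Several changes in this file replace raw os.environ reads with get_bool_env_var. The pitfall being fixed: os.environ.get returns the raw string whenever the variable is set, and any non-empty string is truthy, so exporting a flag as "false" or "0" would still count as enabled. A small illustration; parse_bool_env below is a hypothetical stand-in written only for this example, not sglang's actual helper (that is get_bool_env_var in sglang.srt.utils, whose accepted spellings and defaults may differ).

import os

# A raw environment read keeps the string value, and non-empty strings are truthy.
os.environ["USE_VLLM_CUTLASS_W8A8_FP8_KERNEL"] = "false"
flag = os.environ.get("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", False)
print(type(flag).__name__, bool(flag))  # str True


def parse_bool_env(name: str, default: str = "false") -> bool:
    # Hypothetical equivalent of a boolean env-var helper: interpret the text
    # of the variable rather than its mere presence.
    return os.environ.get(name, default).strip().lower() in ("1", "true", "yes")


print(parse_bool_env("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL"))  # False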
@@ -264,7 +264,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_method = None
         if quant_config is not None:
             quant_method = quant_config.get_quant_method(self, prefix=prefix)
-            print("quant_method", quant_method)
         if quant_method is None:
             quant_method = UnquantizedEmbeddingMethod()
...
@@ -69,7 +69,7 @@ class TestQwen2FP8(unittest.TestCase):
         )
         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreater(metrics["accuracy"], 0.79)
+        self.assertGreater(metrics["accuracy"], 0.78)

 if __name__ == "__main__":
...