Unverified Commit fa46e2bd authored by fzyzcjy, committed by GitHub

Support offloading in fp8 (#9948)

parent b047b553
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import (
......@@ -31,7 +33,15 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.offloader import get_offloader
from sglang.srt.utils import (
ceil_div,
dispose_tensor,
get_bool_env_var,
is_cuda,
is_hip,
is_npu,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
......@@ -535,6 +545,24 @@ class DeepEPMoE(EPMoE):
N = self.w13_weight.size(1)
scale_block_size = 128
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
)
hidden_states_fp8_shape = hidden_states_fp8.shape
hidden_states_fp8_device = hidden_states_fp8.device
hidden_states_fp8_dtype = hidden_states_fp8.dtype
......@@ -565,12 +593,17 @@ class DeepEPMoE(EPMoE):
)
output_index = torch.empty_like(topk_idx)
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
if get_offloader().forbid_copy_engine_usage:
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
num_recv_tokens_per_expert
)
else:
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
......@@ -595,7 +628,7 @@ class DeepEPMoE(EPMoE):
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
input_tensor[1] = tma_align_input_scale(input_tensor[1])
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
input_tensor, w13_weight_fp8, gateup_output, m_indices
)
del input_tensor
down_input = torch.empty(
......@@ -625,7 +658,7 @@ class DeepEPMoE(EPMoE):
down_input_scale = tma_align_input_scale(down_input_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
self.w2_weight_fp8,
w2_weight_fp8,
down_output,
m_indices,
)
......@@ -885,3 +918,12 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
if get_moe_expert_parallel_world_size() > 1:
return EPMoE
return FusedMoE
def copy_list_to_gpu_no_ce(arr: List[int]):
from sgl_kernel.elementwise import copy_to_gpu_no_ce
tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
return tensor_gpu
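
For illustration, a minimal usage sketch of the helper above (the list values are hypothetical; judging from the else-branch in DeepEPMoE, it should yield the same tensor as the pinned-memory `.cuda(non_blocking=True)` path, just without going through the copy engine):

import torch

counts = [3, 0, 7, 1]  # hypothetical per-expert token counts
counts_gpu = copy_list_to_gpu_no_ce(counts)
# Same result as the copy-engine path used when the offloader does not forbid it.
assert torch.equal(counts_gpu, torch.tensor(counts, dtype=torch.int32, device="cuda"))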
......@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
import torch
from sglang.srt import offloader
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
......@@ -417,10 +418,14 @@ def block_quant_dequant(
def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
assert isinstance(weight, torch.nn.Parameter)
assert isinstance(weight_scale_inv, torch.nn.Parameter)
weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
weight, weight_scale_inv, weight_block_size
new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
)
offloader.update_param(weight, new_weight)
weight_scale_inv.data = new_weight_scale_inv
def _requant_weight_ue8m0(
weight: torch.Tensor,
......
......@@ -2244,8 +2244,15 @@ class DeepseekV2Model(nn.Module):
[
"w13_weight",
"w2_weight",
"w13_blockscale_swizzled",
"w2_blockscale_swizzled",
# only for nvfp4
*(
[
"w13_blockscale_swizzled",
"w2_blockscale_swizzled",
]
if hasattr(module, "w13_blockscale_swizzled")
else []
),
]
if isinstance(module, FusedMoE)
else []
......
......@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
def post_init(self):
pass
@property
def forbid_copy_engine_usage(self):
return False
class NoopOffloader(BaseOffloader):
pass
......@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
for i in range(self.prefetch_step):
self.offloaders[i].start_onload()
@property
def forbid_copy_engine_usage(self):
return self.mode == "cpu"
def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
def _on_forward_end():
......@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
return self.shm_cpu_data.to("cuda", non_blocking=True)
def update_param(param, new_tensor):
"""Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
if param.device == new_tensor.device:
param.data = new_tensor
else:
assert param.device == torch.device(
"cpu"
), f"{param.device=} {new_tensor.device=}"
param.data = _create_cpu_data(new_tensor, pin_memory=True)
def _move_param_to_cpu(param, pin_memory: bool):
param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
def _create_cpu_data(data, pin_memory: bool):
cpu_data = _empty_strided_like(
param.data,
data,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(param.data)
param.data = cpu_data
cpu_data.copy_(data)
return cpu_data
def _move_param_to_meta(module, param_name):
......
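
As a rough illustration of the update_param semantics added in this commit (devices, shapes, and values here are hypothetical): when a parameter has already been parked on pinned CPU memory by the offloader, the replacement tensor is copied into freshly allocated pinned host memory instead of simply rebinding .data, so the pinned-host-memory property the offloader relies on is preserved.

import torch

# Hypothetical weight that the offloader has moved to pinned CPU memory.
param = torch.nn.Parameter(torch.zeros(4, 4, device="cuda"), requires_grad=False)
_move_param_to_cpu(param, pin_memory=True)

# A replacement produced on GPU, e.g. by requant_weight_ue8m0_inplace.
new_weight = torch.randn(4, 4, device="cuda")
update_param(param, new_weight)

# The parameter stays on CPU and stays pinned.
assert param.device == torch.device("cpu") and param.data.is_pinned()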