Unverified Commit e59ca942 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent a57a3044
...@@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict): ...@@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict):
num_stages: int num_stages: int
def benchmark_config( def benchmark_config(config: BenchmarkConfig,
config: BenchmarkConfig, num_tokens: int,
num_tokens: int, num_experts: int,
num_experts: int, shard_intermediate_size: int,
shard_intermediate_size: int, hidden_size: int,
hidden_size: int, topk: int,
topk: int, dtype: torch.dtype,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_fp8_w8a8: bool, use_int8_w8a16: bool,
use_int8_w8a16: bool, num_iters: int = 100,
num_iters: int = 100, block_quant_shape: List[int] = None,
block_quant_shape: List[int] = None, use_deep_gemm: bool = False) -> float:
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16: if use_int8_w8a16:
...@@ -115,22 +114,41 @@ def benchmark_config( ...@@ -115,22 +114,41 @@ def benchmark_config(
def run(): def run():
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import override_config
with override_config(config): with override_config(config):
fused_moe( if use_deep_gemm:
x, topk_weights, topk_ids = fused_topk(x, input_gating, topk,
w1, False)
w2, return fused_experts(
input_gating, x,
topk, w1,
renormalize=True, w2,
inplace=True, topk_weights,
use_fp8_w8a8=use_fp8_w8a8, topk_ids,
use_int8_w8a16=use_int8_w8a16, inplace=True,
w1_scale=w1_scale, use_fp8_w8a8=use_fp8_w8a8,
w2_scale=w2_scale, w1_scale=w1_scale,
a1_scale=a1_scale, w2_scale=w2_scale,
a2_scale=a2_scale, a1_scale=a1_scale,
block_shape=block_quant_shape, a2_scale=a2_scale,
) block_shape=block_quant_shape,
allow_deep_gemm=True,
)
else:
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
)
# JIT compilation & warmup # JIT compilation & warmup
run() run()
...@@ -366,6 +384,7 @@ class BenchmarkWorker: ...@@ -366,6 +384,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
block_quant_shape: List[int] = None, block_quant_shape: List[int] = None,
use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed) current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype, dtype_str = get_config_dtype_str(dtype,
...@@ -396,7 +415,8 @@ class BenchmarkWorker: ...@@ -396,7 +415,8 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=100, num_iters=100,
block_quant_shape=block_quant_shape) block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm)
return config, kernel_time return config, kernel_time
def tune( def tune(
...@@ -411,6 +431,7 @@ class BenchmarkWorker: ...@@ -411,6 +431,7 @@ class BenchmarkWorker:
use_int8_w8a16: bool, use_int8_w8a16: bool,
search_space: list[dict[str, int]], search_space: list[dict[str, int]],
block_quant_shape: list[int], block_quant_shape: list[int],
use_deep_gemm: bool,
) -> dict[str, int]: ) -> dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
...@@ -436,7 +457,8 @@ class BenchmarkWorker: ...@@ -436,7 +457,8 @@ class BenchmarkWorker:
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=20, num_iters=20,
block_quant_shape=block_quant_shape) block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm)
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile. # Some configurations may be invalid and fail to compile.
continue continue
...@@ -550,6 +572,8 @@ def main(args: argparse.Namespace): ...@@ -550,6 +572,8 @@ def main(args: argparse.Namespace):
else: else:
batch_sizes = [args.batch_size] batch_sizes = [args.batch_size]
use_deep_gemm = bool(args.use_deep_gemm)
ray.init() ray.init()
num_gpus = int(ray.available_resources()["GPU"]) num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
...@@ -572,10 +596,10 @@ def main(args: argparse.Namespace): ...@@ -572,10 +596,10 @@ def main(args: argparse.Namespace):
start = time.time() start = time.time()
configs = _distribute( configs = _distribute(
"tune", "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape) block_quant_shape, use_deep_gemm)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
best_configs = { best_configs = {
M: sort_config(config) M: sort_config(config)
for M, config in zip(batch_sizes, configs) for M, config in zip(batch_sizes, configs)
...@@ -589,7 +613,7 @@ def main(args: argparse.Namespace): ...@@ -589,7 +613,7 @@ def main(args: argparse.Namespace):
outputs = _distribute( outputs = _distribute(
"benchmark", "benchmark",
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
use_fp8_w8a8, use_int8_w8a16, block_quant_shape) use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
...@@ -611,6 +635,7 @@ if __name__ == "__main__": ...@@ -611,6 +635,7 @@ if __name__ == "__main__":
type=str, type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"], choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto") default="auto")
parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
......
...@@ -6,12 +6,22 @@ import itertools ...@@ -6,12 +6,22 @@ import itertools
import pytest import pytest
import torch import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0): if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True) allow_module_level=True)
...@@ -21,17 +31,18 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] ...@@ -21,17 +31,18 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824] D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 256, 512] GROUP_SIZE = [64, 128, 256, 512]
M = [1, 7, 83, 512, 2048] M = [1, 7, 8, 83, 84, 512, 2048, 4096]
N = [128, 512, 1024, 4096, 7748, 13824] N = [128, 512, 1024, 4096, 7168, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824] K = [256, 4096, 5120, 3884, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168. # and its hidden size is 7168.
M_moe = [1, 7, 83, 512, 2048] M_moe = [1, 2, 7, 83, 128, 512, 2048]
N_moe = [4608] # [128, 4608, 13824] M_moe_dg = [128, 192, 512, 1335, 2048]
K_moe = [7168] # [256, 7168, 13824] N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
E = [8, 24] # [8, 24, 128, 256] E = [2, 8, 16, 24] # [128, 256]
TOP_KS = [2] # [1, 2, 6] TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0] SEEDS = [0]
...@@ -217,11 +228,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -217,11 +228,16 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
SEEDS)) SEEDS))
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed) torch.manual_seed(seed)
factor_for_scale = 1e-2 factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
vllm_config = VllmConfig()
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = (torch.rand( w1_bf16 = (torch.rand(
...@@ -246,25 +262,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ...@@ -246,25 +262,240 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
out = fused_moe( # Set the context to avoid lots of warning spam.
a, with set_current_vllm_config(vllm_config):
w1, out = fused_moe(
w2, a,
score, w1,
topk, w2,
renormalize=False, score,
use_fp8_w8a8=True, topk,
w1_scale=w1_s, renormalize=False,
w2_scale=w2_s, use_fp8_w8a8=True,
block_shape=block_size, w1_scale=w1_s,
) w2_scale=w2_s,
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape=block_size,
block_size) )
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
print(f"{out.sum()=}") block_size)
print(f"{ref_out.sum()=}")
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max = fp8_info.max
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1]
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001
def fp8_perm(m, idx):
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
M, K = a.shape
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
num_tokens = topk * M
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
a = fp8_perm(a, sorted_token_ids // topk)
if a_s is not None:
a_s = a_s[sorted_token_ids // topk]
return a, a_s, m_indices, inv_perm
def test_moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
block_shape):
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
num_groups = w1.shape[0]
M, K = a.shape
N = w2.shape[-1]
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = per_token_group_quant_fp8(a, block_m)
a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids,
num_groups, topk, block_m)
inter_out = torch.zeros((a_q.shape[0], N * 2),
dtype=torch.bfloat16,
device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight)
return final_out
@pytest.mark.parametrize(
"M,N,K,E,topk,seed",
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
# only aligned sizes
if (N % block_m != 0 or K % block_m != 0 or topk > E):
pytest.skip(
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
if (N <= 512):
pytest.skip("Skipping N <= 512 until performance issues solved.")
vllm_config = VllmConfig()
torch.manual_seed(seed)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
score = torch.randn((M, E), dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w2 = (N + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(E):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
if M >= 128:
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
score, topk, block_size)
else:
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean( rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32)))) torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03 assert rel_diff < 0.03
...@@ -1224,7 +1224,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, ...@@ -1224,7 +1224,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies: torch.Tensor, token_expert_indicies: torch.Tensor,
gating_output: float) -> None: gating_output: torch.Tensor) -> None:
torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
token_expert_indicies, gating_output) token_expert_indicies, gating_output)
......
...@@ -105,6 +105,7 @@ if TYPE_CHECKING: ...@@ -105,6 +105,7 @@ if TYPE_CHECKING:
VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -687,6 +688,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -687,6 +688,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_BUCKET_PADDING_GAP": "VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] ...@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__) logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
class Fp8Config(QuantizationConfig): class Fp8Config(QuantizationConfig):
"""Config class for FP8.""" """Config class for FP8."""
...@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False) requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False) requires_grad=False)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
return return
# If checkpoint is fp16, quantize in place. # If checkpoint is fp16, quantize in place.
...@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment