Unverified Commit 65b7c9b7 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

cleanup deps 2/n (#4464)

parent 2c4f5cca
......@@ -23,7 +23,9 @@ import torch.nn.functional as F
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from sglang.srt.custom_op import CustomOp
......@@ -165,7 +167,7 @@ def get_act_fn(
return act_fn
if not is_cuda_available():
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
......
......@@ -21,7 +21,9 @@ import torch.nn as nn
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
_is_cuda = is_cuda_available()
if _is_cuda:
from sgl_kernel import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
......@@ -117,7 +119,7 @@ class GemmaRMSNorm(CustomOp):
return out
if not is_cuda_available():
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
......
......@@ -41,9 +41,13 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_hip
from sglang.srt.utils import add_prefix, is_cuda, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import awq_dequantize
class DeepseekModelNextN(nn.Module):
......@@ -261,14 +265,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self_attn = self.model.decoder.self_attn
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
w = ops.awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
if _is_cuda:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
).T
else:
w = ops.awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
......
......@@ -68,12 +68,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
_is_hip = is_hip()
_is_cuda = is_cuda()
if is_cuda_available():
from sgl_kernel import bmm_fp8
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8
class DeepseekV2MLP(nn.Module):
......@@ -1174,14 +1175,21 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn = self.model.layers[layer_id].self_attn
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
w = ops.awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
if _is_cuda:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
).T
else:
w = ops.awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
self_attn.kv_b_proj.qzeros,
0,
0,
0,
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment