Unverified Commit 65b7c9b7 authored by Yineng Zhang, committed by GitHub

cleanup deps 2/n (#4464)

parent 2c4f5cca
@@ -23,7 +23,9 @@ import torch.nn.functional as F
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 from sglang.srt.custom_op import CustomOp
@@ -165,7 +167,7 @@ def get_act_fn(
     return act_fn
 
 
-if not is_cuda_available():
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
...
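This file's two hunks apply the same cleanup: the platform probe is evaluated once into a module-level flag, and every later check reads _is_cuda instead of calling is_cuda_available() again; the next file repeats the pattern. A minimal sketch of the idea, written as a single if/else for brevity (the probe below is an illustrative stand-in, not the actual sglang.srt.utils implementation, and in the commit the import guard and the log message sit in different parts of the module):

import logging

import torch

logger = logging.getLogger(__name__)


def is_cuda_available() -> bool:
    # Illustrative stand-in for sglang.srt.utils.is_cuda_available,
    # assumed here to wrap a torch-level CUDA check.
    return torch.cuda.is_available()


# Evaluate the probe once at import time; later code only reads the flag.
_is_cuda = is_cuda_available()

if _is_cuda:
    # CUDA-only fused kernels; the guard keeps non-NV platforms importable.
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul  # noqa: F401
else:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. "
        "Fallback to other kernel libraries."
    )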
@@ -21,7 +21,9 @@ import torch.nn as nn
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -117,7 +119,7 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_cuda_available():
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
...
@@ -41,9 +41,13 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
 
 
 class DeepseekModelNextN(nn.Module):
@@ -261,14 +265,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
             # AWQ compatible
-            w = ops.awq_dequantize(
-                self_attn.kv_b_proj.qweight,
-                self_attn.kv_b_proj.scales,
-                self_attn.kv_b_proj.qzeros,
-                0,
-                0,
-                0,
-            ).T
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
+            else:
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
         else:
             w = self_attn.kv_b_proj.weight
         # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
...
@@ -68,12 +68,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
-if is_cuda_available():
-    from sgl_kernel import bmm_fp8
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
@@ -1174,14 +1175,21 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn = self.model.layers[layer_id].self_attn
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                w = ops.awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
+                else:
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
             else:
                 w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
...
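The two DeepSeek hunks above make the same switch on the AWQ path: on CUDA the kv_b_proj tensors are dequantized with sgl_kernel's awq_dequantize, which takes only qweight, scales, and qzeros, while other platforms keep the previous ops.awq_dequantize call with its three trailing positional arguments set to 0. A hedged sketch of that dispatch pulled out into a helper (the helper name is hypothetical, and treating ops as vLLM's _custom_ops module is an assumption inferred from the call shape, not something stated in this diff):

import torch


def dequantize_awq_kv_b_proj(kv_b_proj, is_cuda: bool) -> torch.Tensor:
    # Hypothetical helper; the commit inlines this logic in the weight
    # post-processing of DeepseekV2ForCausalLM and DeepseekV3ForCausalLMNextN.
    if is_cuda:
        from sgl_kernel import awq_dequantize

        # sgl_kernel variant: tensor arguments only.
        w = awq_dequantize(
            kv_b_proj.qweight,
            kv_b_proj.scales,
            kv_b_proj.qzeros,
        )
    else:
        # Assumed fallback: vLLM's custom ops, whose awq_dequantize keeps the
        # older signature with three extra positional arguments set to 0.
        from vllm import _custom_ops as ops

        w = ops.awq_dequantize(
            kv_b_proj.qweight,
            kv_b_proj.scales,
            kv_b_proj.qzeros,
            0,
            0,
            0,
        )
    # Both branches transpose, matching the original ops.awq_dequantize(...).T,
    # so the bmm_fp8 requantization noted in the diff sees the same layout.
    return w.T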