Unverified Commit a39d9287 authored by Yijie Zhu, committed by GitHub

Support running Qwen2 on Ascend NPU devices (#7022)


Co-authored-by: 刁莹煜 <diaoyingyu1@hisilicon.com>
parent 10d60cd4
......@@ -51,6 +51,10 @@
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0\"\n",
")\n",
"## To run qwen2.5-0.5b-instruct model on the Ascend-Npu, you can execute the following command:\n",
"# server_process, port = launch_server_cmd(\n",
"# \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0 --device npu --tp 2 --attention-backend torch_native\"\n",
"# )\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
......
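Once wait_for_server returns, a quick end-to-end sanity check of the NPU launch is possible from the same notebook. This is a sketch: it assumes the OpenAI-compatible /v1/chat/completions route that sglang's launch_server exposes, and reuses the port returned by launch_server_cmd above.

# Minimal sanity check of the launched server (sketch; payload follows
# sglang's OpenAI-compatible API, `port` comes from launch_server_cmd).
import requests

resp = requests.post(
    f"http://localhost:{port}/v1/chat/completions",
    json={
        "model": "qwen/qwen2.5-0.5b-instruct",
        "messages": [{"role": "user", "content": "Say hello."}],
    },
)
print(resp.json()["choices"][0]["message"]["content"])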
......@@ -4,7 +4,7 @@ from typing import List, Tuple
import torch
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
......@@ -25,7 +25,7 @@ if not is_hpu():
logger.warning("Failed to import from custom_ar with %r", e)
if not is_hip():
if not is_hip() and not is_npu():
if use_vllm_custom_allreduce:
custom_op = torch.ops._C_custom_ar
else:
......
......@@ -29,10 +29,11 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
from sglang.utils import resolve_obj_by_qualname
_is_cuda = is_cuda()
_is_npu = is_npu()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
......@@ -184,7 +185,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return nn.Identity()
if not _is_cuda:
if not _is_cuda and not _is_npu:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
......
......@@ -20,10 +20,11 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
......@@ -187,7 +188,7 @@ class Gemma3RMSNorm(nn.Module):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
if not (_is_cuda or _is_hip):
if not (_is_cuda or _is_hip or _is_npu):
logger.info(
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
)
......
......@@ -17,11 +17,12 @@ from sglang.srt.layers.quantization.utils import (
per_tensor_dequantize,
replace_parameter,
)
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
_is_cuda = is_cuda()
_is_npu = is_npu()
if not _is_cuda:
if not _is_cuda and not _is_npu:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant
......
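The hunks above and below all apply the same platform gate: the vLLM custom-op fallback is imported only when neither the CUDA sgl-kernel path nor the native NPU path applies. A condensed sketch of that recurring pattern (helper names are from sglang.srt.utils; the fallback import is the one shown in the diff):

# Condensed sketch of the platform gate this commit applies across files.
from sglang.srt.utils import is_cuda, is_npu

_is_cuda = is_cuda()
_is_npu = is_npu()

if not _is_cuda and not _is_npu:
    # Only platforms with neither sgl-kernel (CUDA) nor native NPU ops
    # fall back to vLLM's compiled custom ops.
    from vllm._custom_ops import scaled_fp8_quant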
......@@ -67,6 +67,7 @@ from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_hip,
is_npu,
log_info_on_rank0,
print_warning_once,
set_weight_attrs,
......@@ -74,6 +75,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
......@@ -86,7 +88,7 @@ if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight
if not _is_cuda:
if not _is_cuda and not _is_npu:
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -6,11 +6,12 @@ from typing import List, Mapping, Tuple, Union
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_npu
_is_cuda = is_cuda()
_is_npu = is_npu()
if not _is_cuda:
if not _is_cuda and not _is_npu:
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -8,10 +8,11 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import is_cuda, is_hip, is_npu
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
if _is_cuda:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
......@@ -84,7 +85,7 @@ class RotaryEmbedding(CustomOp):
if not _is_cuda:
cache = cache.to(dtype)
if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
if not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]:
from vllm._custom_ops import rotary_embedding
self.vllm_rotary_embedding = rotary_embedding
......
......@@ -1291,6 +1291,15 @@ def get_hpu_memory_capacity():
)
def get_npu_memory_capacity():
try:
import torch_npu  # noqa: F401 (import side effect exposes torch.npu)
return torch.npu.mem_get_info()[1] // 1024 // 1024 # unit: MB
except ImportError as e:
raise ImportError("torch_npu is required when running on an NPU device.") from e
def get_device_memory_capacity(device: str = None):
if is_cuda():
gpu_mem = get_nvgpu_memory_capacity()
......@@ -1298,6 +1307,8 @@ def get_device_memory_capacity(device: str = None):
gpu_mem = get_amdgpu_memory_capacity()
elif device == "hpu":
gpu_mem = get_hpu_memory_capacity()
elif device == "npu":
gpu_mem = get_npu_memory_capacity()
else:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
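A quick check of the new NPU branch, assuming torch_npu is installed and that torch.npu.mem_get_info follows the CUDA convention of returning (free, total) in bytes:

# Hypothetical usage of the new "npu" branch (requires torch_npu).
total_mb = get_device_memory_capacity("npu")
print(f"NPU total memory: {total_mb} MB")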
......@@ -1423,6 +1434,11 @@ def get_device(device_id: Optional[int] = None) -> str:
return "xpu"
return "xpu:{}".format(device_id)
if hasattr(torch, "npu") and torch.npu.is_available():
if device_id is None:
return "npu"
return "npu:{}".format(device_id)
if is_habana_available():
try:
import habana_frameworks.torch.hpu
......@@ -1497,15 +1513,35 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return major, minor
def get_npu_compiler_config():
config = {
"frozen_parameter": True,
"tiling_schedule_optimize": True,
"topology_sorting_strategy": "StableRDFS",
}
return config
def get_compiler_backend() -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return "hpu_backend"
if hasattr(torch, "npu") and torch.npu.is_available():
import torchair
try:
import torchair
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
from torchair.configs.compiler_config import CompilerConfig
except ImportError as e:
raise ImportError(
"NPU detected, but torchair package is not installed. "
"Please install torchair for torch.compile support on NPU."
) from e
compiler_config = CompilerConfig()
predefined_config = get_npu_compiler_config()
for k, v in predefined_config.items():
setattr(compiler_config.experimental_config, k, v)
config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
npu_backend = torchair.get_npu_backend(compiler_config=compiler_config)
return npu_backend
return "inductor"
......
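For context, a sketch of how the backend returned by get_compiler_backend would be consumed on an NPU host. It assumes torch_npu and torchair are installed and at least one Ascend device is visible; the model and shapes are illustrative only.

# Illustrative only: compile a small module with the torchair NPU backend
# that get_compiler_backend() returns when torch.npu is available.
import torch

backend = get_compiler_backend()  # torchair backend on NPU, "inductor" elsewhere
model = torch.nn.Linear(16, 16).to("npu")
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(4, 16, device="npu"))
print(out.shape)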