Unverified Commit 198974cd authored by Yineng Zhang, committed by GitHub

feat: support sm75 with FlashInfer v0.1.6 (#1233)

parent 6cc38b2b
@@ -135,7 +135,7 @@ sky status --endpoint 30000 sglang
 ### Common Notes
-- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. If you are using NVIDIA GPU devices below sm80, such as T4, you can't use SGLang for the time being. We expect to resolve this issue soon, so please stay tuned. If you encounter any FlashInfer-related issues on sm80+ devices (e.g., A100, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise a issue.
+- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
 ## Backend: SGLang Runtime (SRT)
...
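The Triton fallback mentioned in the note above can also be selected programmatically. A minimal sketch, assuming `sgl.Runtime` forwards keyword arguments such as `disable_flashinfer` and `disable_flashinfer_sampling` to `ServerArgs` (the same fields the CLI flags set), with a placeholder model path:

```python
# Sketch: run SGLang with Triton kernels instead of FlashInfer.
# Assumes sgl.Runtime forwards these keyword args to ServerArgs
# (the CLI flags --disable-flashinfer / --disable-flashinfer-sampling
# map to these fields); the model path is a placeholder.
import sglang as sgl

runtime = sgl.Runtime(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # any local or HF model path
    disable_flashinfer=True,            # use Triton attention kernels
    disable_flashinfer_sampling=True,   # skip FlashInfer sampling kernels
)
# ... issue requests against the launched server ...
runtime.shutdown()
```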
@@ -30,18 +30,11 @@ from vllm.model_executor.utils import set_weight_attrs
 class SiluAndMul(CustomOp):
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
-
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        if self.is_lower_sm80:
-            return self.forward_native(x)
-
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
...
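For reference, the computation that `forward_native` performs (and that the fused CUDA kernel reproduces into `out`) is the gated SiLU below; with the sub-sm80 fallback removed, sm75 devices now take the fused path directly. A standalone sketch of the same math:

```python
# Reference (non-fused) gated SiLU: split the last dimension in half,
# apply SiLU to the first half, and gate (multiply) by the second half.
import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 8)                       # last dim must be even (2 * d)
assert silu_and_mul_reference(x).shape == (2, 4)
```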
@@ -32,15 +32,12 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8

     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self.is_lower_sm80:
-            return self.forward_native(x, residual)
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
...
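The same sub-sm80 guard is dropped from `RMSNorm.forward_cuda`. For orientation, a sketch of the RMSNorm math (with the optional fused residual add) that the CUDA kernels compute, assuming the standard RMSNorm formulation rather than the library's exact kernel code:

```python
# Sketch of RMSNorm with an optional fused residual add. When a residual is
# given, it is added before normalizing and the updated residual is returned
# alongside the output, mirroring fused_add_rmsnorm's in-place semantics.
# Standard formulation; not the library's exact implementation.
import torch

def rmsnorm_reference(x, weight, eps=1e-6, residual=None):
    if residual is not None:
        x = x + residual                          # add residual first
        residual = x                              # updated residual is also returned
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)
```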
@@ -161,6 +161,8 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")

         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
...
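`torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, and the new `minor < 5` check is only reached inside the existing `major < 8` branch shown above, so the combined condition rejects pre-sm75 GPUs while still admitting sm75 (T4) and sm80+. A small sketch of the resulting behavior:

```python
# Illustration of the guard's effective condition: major < 8 and minor < 5
# means the device is older than sm75 and is rejected.
def is_supported(capability: tuple) -> bool:
    major, minor = capability
    if major < 8:          # existing branch that also forces float16
        if minor < 5:      # the check added in this commit
            return False   # e.g. sm70 (V100), sm61 (P40): unsupported
    return True            # sm75 (T4) and sm80+ (A100, H100): supported

assert is_supported((7, 5))      # T4
assert not is_supported((7, 0))  # V100
assert is_supported((8, 0))      # A100
```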
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.5",
+            "0.1.6",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
...
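The bump from 0.1.5 to 0.1.6 means an older flashinfer wheel left in the environment will now fail this assertion at startup. As a rough illustration of what is being guarded (not SGLang's actual `assert_pkg_version` implementation), the installed version can be checked with the standard library plus `packaging`:

```python
# Illustration only: verify the locally installed flashinfer version.
from importlib.metadata import version, PackageNotFoundError
from packaging.version import Version

try:
    installed = Version(version("flashinfer"))
except PackageNotFoundError:
    raise RuntimeError("flashinfer is not installed") from None

if installed < Version("0.1.6"):
    raise RuntimeError(
        f"flashinfer {installed} is too old; please reinstall the latest version "
        "following https://docs.flashinfer.ai/installation.html"
    )
```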