"git@developer.sourcefind.cn:change/sglang.git" did not exist on "eb38c7d1cae1c616de3c1c0ce40353c720f7e3c7"
Unverified commit 94752ac8, authored by Yineng Zhang and committed by GitHub

feat: use FlashInfer rmsnorm and silu (#907)

parent 43fbb6d9
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from flashinfer.activation import silu_and_mul


class SiluAndMul(nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Reference implementation: split [gate, up] and compute SiLU(gate) * up.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fused FlashInfer kernel: writes SiLU(x[..., :d]) * x[..., d:] into out.
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
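For reference, a minimal usage sketch of the new layer (not part of the commit; shapes, dtype, and the CUDA device are illustrative): the input is the concatenated [gate, up] projection of width 2*d and the output has width d.

import torch

from sglang.srt.layers.activation import SiluAndMul

# Illustrative only: 4 tokens, hidden size d = 4096, fused FlashInfer kernel.
layer = SiluAndMul()
x = torch.randn(4, 2 * 4096, dtype=torch.float16, device="cuda")
out = layer(x)
print(out.shape)  # torch.Size([4, 4096])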
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm


class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # Fused FlashInfer kernel: adds the residual into x and normalizes,
            # updating both tensors in place.
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Reference implementation in float32, used for correctness checks.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual
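For reference, a minimal usage sketch (not part of the commit; assumes a CUDA device and illustrative shapes). The fused residual path updates x and residual in place, so callers that still need the original tensors should pass clones.

import torch

from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(4096).to(dtype=torch.float16, device="cuda")
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

out = norm(x.clone())  # plain rmsnorm kernel
# Fused add + rmsnorm: the returned tensors alias the (mutated) inputs.
out2, new_residual = norm(x.clone(), residual.clone())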
@@ -23,8 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class InternLM2MLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):
 class InternLM2Attention(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):
 class InternLMDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):
 class InternLM2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):
 class InternLM2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
......
@@ -24,8 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,6 +37,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
......
@@ -384,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.3",
+            "0.1.4",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
......
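To see which FlashInfer build is actually installed before this check fires, the package metadata can be queried directly (a quick sketch, not part of the change):

from importlib.metadata import PackageNotFoundError, version

try:
    print("flashinfer", version("flashinfer"))
except PackageNotFoundError:
    print("flashinfer is not installed")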
import itertools
import unittest

import torch

from sglang.srt.layers.layernorm import RMSNorm


class TestRMSNorm(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_rms_norm(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                hidden_size=params[1],
                add_residual=params[2],
                dtype=params[3],
                seed=params[4],
            ):
                self._run_rms_norm_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)
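The new activation layer could be exercised in the same style; a hypothetical companion test (mirroring the structure above, not part of the commit; shapes and tolerances are illustrative) might look like:

import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul


class TestSiluAndMul(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def test_silu_and_mul(self):
        torch.manual_seed(0)
        layer = SiluAndMul()
        for dtype in (torch.half, torch.bfloat16):
            for num_tokens in (7, 83):
                with self.subTest(dtype=dtype, num_tokens=num_tokens):
                    x = torch.randn(num_tokens, 2 * 4096, dtype=dtype)
                    ref = layer.forward_native(x)
                    out = layer(x)
                    self.assertTrue(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))


if __name__ == "__main__":
    unittest.main(verbosity=2)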