Commit e00b0a19 authored by zhuwenwen

merge v0.3.3

parents ead94d93 3f1166ab
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
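For example, for the TP4 Mixtral setting above (E = 8 experts, N = 3584), the kernel looks for a JSON file named after E, N, and the device, and then picks the tuned entry whose M is closest to the current batch size. A small sketch of that lookup (the device string comes from torch.cuda.get_device_name(), so the exact filename is only illustrative):

import json
import os
import torch

E, N = 8, 3584  # Mixtral, TP4
device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file = f"E={E},N={N},device_name={device_name}.json"
if os.path.exists(config_file):
    with open(config_file) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    M = 48  # current batch size
    config = configs[min(configs, key=lambda x: abs(x - M))]  # nearest tuned M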
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,
and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids`
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
return sorted_ids, expert_ids, num_tokens_post_pad
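# --- Illustrative reference (not part of the kernel path) --------------------
# A CPU-only sketch of what ops.moe_align_block_size produces for the worked
# example in the docstring above.  The real op writes into the fixed-size
# buffers allocated above; this simplified version returns exact-length
# tensors and only exists to make the example concrete.
def _moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int,
                                    num_experts: int):
    num_tokens = topk_ids.numel()  # tokens repeated top_k times
    flat = topk_ids.flatten()
    sorted_chunks, expert_ids = [], []
    for expert in range(num_experts):
        ids = (flat == expert).nonzero(as_tuple=True)[0]
        # Pad each expert's token list up to a multiple of block_size with the
        # out-of-range id `num_tokens`, which the kernel masks out later.
        pad = (-ids.numel()) % block_size
        padded = torch.cat([ids, torch.full((pad, ), num_tokens)])
        sorted_chunks.append(padded)
        expert_ids.extend([expert] * (padded.numel() // block_size))
    sorted_token_ids = torch.cat(sorted_chunks).to(torch.int32)
    return (sorted_token_ids, torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([sorted_token_ids.numel()], dtype=torch.int32))
# Example from the docstring (its expert ids run 1..4, so num_experts=5 covers
# index 4 and expert 0 simply receives no tokens):
# _moe_align_block_size_reference(
#     torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]), 4, 5)
# -> sorted ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12], ...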
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
fused_moe_kernel[grid](
A,
B,
C,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
**config,
)
@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of batch sizes
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
size bs, the closest batch size in the grid should be picked and the associated
configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs directory
device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs",
f"E={E},N={N},device_name={device_name}.json")
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
f"Using configuration from {config_file_path} for MoE layer.")
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default configuration
return None
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2])
if configs:
# If an optimal configuration map has been found, look up the optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, False,
topk_ids.shape[1], config)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
topk_weights, topk_ids, sorted_token_ids,
expert_ids, num_tokens_post_padded, True, 1,
config)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
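# --- Illustrative usage (assumed sizes, not an official example) -------------
# A minimal sketch of calling fused_moe() with Mixtral-like shapes for one TP4
# shard: hidden=4096, per-shard intermediate N=3584, E=8 experts, top-2
# routing.  Requires a CUDA build of vLLM's custom ops.
if __name__ == "__main__":
    hidden_states = torch.randn(32, 4096, dtype=torch.float16, device="cuda")
    w1 = torch.randn(8, 2 * 3584, 4096, dtype=torch.float16,
                     device="cuda")  # fused gate & up projections
    w2 = torch.randn(8, 4096, 3584, dtype=torch.float16,
                     device="cuda")  # down projection
    gating_output = torch.randn(32, 8, dtype=torch.float32, device="cuda")
    out = fused_moe(hidden_states, w1, w2, gating_output, topk=2,
                    renormalize=True)
    assert out.shape == (32, 4096)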
@@ -17,6 +17,14 @@ from vllm.logger import init_logger

 logger = init_logger(__name__)

+def adjust_marlin_shard(param, shard_size, shard_offset):
+    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+    if marlin_tile_size is None:
+        return shard_size, shard_offset
+
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
 class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

@@ -54,7 +62,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
                        params_dtype: torch.dtype) -> Dict[str, Any]:
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
-                                       device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

@@ -113,9 +120,7 @@ class ReplicatedLinear(torch.nn.Module):
         self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
-                torch.empty(self.output_size,
-                            device=torch.cuda.current_device(),
-                            dtype=self.params_dtype))
+                torch.empty(self.output_size, dtype=self.params_dtype))
             set_weight_attrs(self.bias, {"output_dim": 0})
         else:
             self.register_parameter("bias", None)

@@ -183,7 +188,6 @@ class ColumnParallelLinear(torch.nn.Module):
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
-                            device=torch.cuda.current_device(),
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,

@@ -280,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)

@@ -297,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size

@@ -376,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                       loaded_shard_id: Optional[str] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:

@@ -397,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)

@@ -421,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":

@@ -509,9 +534,7 @@ class RowParallelLinear(torch.nn.Module):
         if bias:
             self.bias = Parameter(
-                torch.empty(self.output_size,
-                            device=torch.cuda.current_device(),
-                            dtype=params_dtype))
+                torch.empty(self.output_size, dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig

 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
+    "marlin": MarlinConfig,
 }
@@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,

@@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,

@@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,

@@ -148,12 +145,21 @@ class AWQLinearMethod(LinearMethodBase):
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
-        qzeros = weights["qzeros"]
         scales = weights["scales"]
+        qzeros = weights["qzeros"]
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
+
+        # num_tokens >= threshold
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
+
+        if FP16_MATMUL_HEURISTIC_CONDITION:
+            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+            out = torch.matmul(reshaped_x, out)
+        else:
+            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
+                               pack_factor)
         if bias is not None:
             out = out + bias
         return out.reshape(out_shape)
 import enum
 from enum import Enum
 from typing import Any, Dict, List, Optional
+from fractions import Fraction

 import torch
 from torch.nn.parameter import Parameter

@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
-        self.pack_factor = 32 // self.weight_bits
-        # exllama kernel v1 only supports 4 bit
-        if self.weight_bits != 4:
+        self.pack_factor = Fraction(32, self.weight_bits)
+        if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
-                "Currently, only 4-bit weight quantization is supported for "
+                "Currently, only 2/3/4/8-bit weight quantization is supported for "
                 f"GPTQ, but got {self.weight_bits} bits.")

     def __repr__(self) -> str:

@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size_per_partition % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "

@@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,

@@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase):
                     i // self.quant_config.group_size
                     for i in range(input_size_per_partition)
                 ],
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,

@@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,

@@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,

@@ -205,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
         else:
             weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
+            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                             self.quant_config.weight_bits)
         output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                                weights["qzeros"], weights["scales"],
                                weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY)
+                               weights["exllama_state"] == ExllamaState.READY,
+                               self.quant_config.weight_bits)
         if bias is not None:
             output = output + bias
         return output.reshape(out_shape)
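A quick aside on the pack_factor change above: for 3-bit GPTQ, 32 bits do not hold a whole number of weights, so an integer pack factor (32 // 3 = 10) would silently round; using Fraction keeps the packed-shape arithmetic and divisibility checks exact. A standalone sketch with values chosen only for illustration:

from fractions import Fraction

pack_factor = Fraction(32, 3)   # 3-bit GPTQ: 32/3 weights per int32 word
print(1536 % pack_factor == 0)  # True  -> 1536 weights pack into exactly 144 words
print(1000 % pack_factor == 0)  # False -> this shard size would mis-align the packed weight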
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""
def __init__(
self,
group_size: int,
) -> None:
# Group size for the quantization.
self.group_size = group_size
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) is supported for "
f"Marlin, but got group_size of {self.group_size}")
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 64
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size}"
@classmethod
def get_name(cls) -> str:
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
# Need to figure it out
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
def get_linear_method(self) -> "MarlinLinearMethod":
return MarlinLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
Args:
quant_config: The Marlin quantization config.
"""
def __init__(self, quant_config: MarlinConfig):
self.quant_config = quant_config
def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
del output_size # Unused.
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
)
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
)
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
)
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
# Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
scales = Parameter(
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
return {
"B": qweight,
"s": scales,
"workspace": workspace,
}
def apply_weights(
self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = weights["B"]
scales = weights["s"]
workspace = weights["workspace"]
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
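# --- Illustrative shape check (assumed sizes, not from the repo) -------------
# For a shard with in_features=4096, out_features=4096 and group_size=128, the
# buffers created in create_weights() above come out as follows:
tile, pack = 16, 8          # tile_size, pack_factor (4-bit values per int32 word)
in_f, out_f, group = 4096, 4096, 128
qweight_shape = (in_f // tile, out_f * tile // pack)  # (256, 8192) int32
scales_shape = (in_f // group, out_f)                 # (32, 4096) fp16
workspace_len = (out_f // 64) * 16                    # min_n_threads=64, max_parallel=16
print(qweight_shape, scales_shape, workspace_len)     # (256, 8192) (32, 4096) 1024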
@@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,

@@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             torch.empty(
                 output_size,
                 self.quant_config.weight_bits**2,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,

@@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         if is_hip():
-            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+            out_f = torch.zeros(out_shape, dtype=torch.float)
             ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
             out = out_f.to(dtype=torch.float16)
         else:
             # NOTE: The output tensor should be zero-initialized.
-            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+            out = torch.zeros(out_shape, dtype=torch.float16)
             ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
         if bias is not None:
@@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module):
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                  self.rotary_dim))
+            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
         return inv_freq

     def _compute_cos_sin_cache(self) -> torch.Tensor:
         """Compute the cos and sin cache."""
         inv_freq = self._compute_inv_freq(self.base)
-        t = torch.arange(self.max_position_embeddings,
-                         dtype=torch.float,
-                         device="cuda")
+        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()

@@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         # Thus, the maximum length after applying the rope scaling is
         # self.max_position_embeddings * self.scaling_factor.
         max_len = self.max_position_embeddings * self.scaling_factor
-        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = torch.arange(max_len, dtype=torch.float)
         t = t / self.scaling_factor

         freqs = torch.einsum("i,j -> ij", t, inv_freq)

@@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
                 (self.scaling_factor - 1))**(self.rotary_dim /
                                              (self.rotary_dim - 2))
         inv_freq = self._compute_inv_freq(base)
-        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = torch.arange(max_len, dtype=torch.float)

         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()

@@ -248,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int,

 def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
-                           dtype: torch.dtype,
-                           device: torch.device) -> torch.Tensor:
+                           dtype: torch.dtype) -> torch.Tensor:
     if low == high:
         high += 0.001  # Prevent singularity

-    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
-                   low) / (high - low)
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func

@@ -297,9 +292,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                          is_neox_style)

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
+            self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

@@ -308,8 +303,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                                                   self.max_position_embeddings)
         # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
-            low, high, self.rotary_dim // 2, dtype=torch.float,
-            device="cuda")) * self.extrapolation_factor
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
         inv_freq = inv_freq_interpolation * (
             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
         return inv_freq

@@ -317,7 +312,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)

@@ -360,7 +354,6 @@ def get_rope(
     elif scaling_type == "yarn":
         original_max_position = rope_scaling[
             "original_max_position_embeddings"]
-        assert max_position == original_max_position * scaling_factor
         extra_kwargs = {
             k: v
             for k, v in rope_scaling.items()
@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTens
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron

 class Sampler(nn.Module):

@@ -27,9 +28,27 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """

-    def __init__(self, vocab_size: int) -> None:
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generate outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
+        # original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits

     def forward(
         self,

@@ -39,11 +58,14 @@ class Sampler(nn.Module):
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)

         # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias,
-                             self.vocab_size)
+        logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because

@@ -98,20 +120,6 @@ class Sampler(nn.Module):
                              prompt_logprobs, sample_logprobs)

-def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
-                embedding_bias: Optional[torch.Tensor],
-                vocab_size: int) -> Optional[torch.Tensor]:
-    # Get the logits for the next tokens.
-    logits = torch.matmul(hidden_states, embedding.t())
-    if embedding_bias is not None:
-        logits += embedding_bias
-    logits = tensor_model_parallel_gather(logits)
-    # Remove paddings in vocab (if any).
-    if logits is not None:
-        logits = logits[:, :vocab_size]
-    return logits
-
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     sampling_metadata: SamplingMetadata,

@@ -341,7 +349,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).

@@ -351,7 +361,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)

@@ -369,6 +387,7 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}

     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.

@@ -383,15 +402,20 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:

@@ -406,9 +430,9 @@ def _sample(
                                                        sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
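The _multinomial change above relies on the exponential trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws from the categorical distribution, and routing that noise through a per-sequence torch.Generator is what makes seeded requests reproducible. A standalone sketch (values chosen for illustration):

import torch

probs = torch.tensor([[0.1, 0.2, 0.7]]).repeat(10000, 1)
gen = torch.Generator().manual_seed(0)
q = torch.empty_like(probs).exponential_(generator=gen)
samples = probs.div(q).argmax(dim=1)
print(torch.bincount(samples, minlength=3) / samples.numel())  # ~ [0.1, 0.2, 0.7]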
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
import triton
import triton.language as tl
if triton.__version__ >= "2.1.0":
@triton.jit
def _fwd_kernel(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
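# The m_i / l_i / alpha / beta bookkeeping in _fwd_kernel above is an online
# (blockwise) softmax: whenever a new K/V block raises the running max, the
# previously accumulated numerator and denominator are rescaled. The NumPy
# sketch below reproduces the same recurrence for a single query row; it is
# purely illustrative (names and shapes are ours, not part of this file).
import numpy as np

def _online_softmax_row(score_blocks, value_blocks):
    """softmax(concat(scores)) @ concat(values), processed one block at a time.

    score_blocks: list of 1-D arrays of attention scores (one per K/V block)
    value_blocks: list of 2-D arrays of shape (block_len, head_dim)
    """
    head_dim = value_blocks[0].shape[1]
    m_i = -np.inf                # running max of all scores seen so far
    l_i = 0.0                    # softmax denominator at scale m_i
    acc = np.zeros(head_dim)     # running, already-normalized output row
    for qk, v in zip(score_blocks, value_blocks):
        m_ij = qk.max()
        p = np.exp(qk - m_ij)
        l_ij = p.sum()
        m_i_new = max(m_i, m_ij)
        alpha = np.exp(m_i - m_i_new)   # rescales the old statistics
        beta = np.exp(m_ij - m_i_new)   # rescales the new block
        l_i_new = alpha * l_i + beta * l_ij
        acc = acc * (l_i / l_i_new * alpha) + (beta / l_i_new) * (p @ v)
        m_i, l_i = m_i_new, l_i_new
    return acc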
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: this sequence's start index along dim 0 of Q/Out
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# compute q[BLOCK_M, BLOCK_DMODEL] @ k[prefix_len:, BLOCK_DMODEL] for the new tokens
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, allow_tf32=False)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=None):
cap = torch.cuda.get_device_capability()
BLOCK = 128 if cap[0] >= 8 else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, query blocks)
num_warps = 8  # both Lk <= 64 and Lk > 64 use 8 warps here
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
alibi_slopes,
v_cache.shape[3],  # block_size
8,  # x: innermost grouping of the 5-D k_cache layout
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
v_cache.shape[3],  # block_size
8,  # x: innermost grouping of the 5-D k_cache layout
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
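# How the kernels address the paged KV cache: a prefix token's position is
# translated through the block table (B_Loc) into a physical block id plus an
# in-block slot, and the head dimension of K is stored in groups of x
# contiguous elements. The host-side sketch below mirrors that indexing for a
# single token; it is illustrative only (function names are ours).
def _gather_cached_key(k_cache, b_loc, batch_idx, token_pos, kv_head,
                       block_size, x):
    # k_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x]
    # b_loc:   [batch, max_blocks_per_seq] block table of physical block ids
    physical_block = b_loc[batch_idx, token_pos // block_size]
    slot = token_pos % block_size
    # element d of the head lives at (group=d // x, within=d % x)
    return k_cache[physical_block, kv_head, :, slot, :].reshape(-1)

def _gather_cached_value(v_cache, b_loc, batch_idx, token_pos, kv_head,
                         block_size):
    # v_cache: [num_blocks, num_kv_heads, head_size, block_size]
    physical_block = b_loc[batch_idx, token_pos // block_size]
    slot = token_pos % block_size
    return v_cache[physical_block, kv_head, :, slot]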
...@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import ( ...@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value.""" """Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to return ((vocab_size + pad_to - 1) // pad_to) * pad_to
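# Worked example of the rounding above (values chosen for illustration):
#   pad_vocab_size(32000, 64) -> 32000   # already a multiple of 64
#   pad_vocab_size(32003, 64) -> 32064   # rounded up to the next multiple of 64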
...@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
num_embeddings: vocabulary size. num_embeddings: vocabulary size.
embedding_dim: size of hidden state. embedding_dim: size of hidden state.
params_dtype: type of the parameters. params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
params_dtype: Optional[torch.dtype] = None): params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__() super().__init__()
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings) self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
...@@ -68,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -68,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter( self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
self.embedding_dim, self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
set_weight_attrs(self.weight, { set_weight_attrs(self.weight, {
"parallel_dim": 0, "parallel_dim": 0,
...@@ -77,7 +85,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -77,7 +85,7 @@ class VocabParallelEmbedding(torch.nn.Module):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self. loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index] vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight) param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
...@@ -114,18 +122,22 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -114,18 +122,22 @@ class ParallelLMHead(VocabParallelEmbedding):
embedding_dim: size of hidden state. embedding_dim: size of hidden state.
bias: whether to use bias. bias: whether to use bias.
params_dtype: type of the parameters. params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
bias: bool = False, bias: bool = False,
params_dtype: Optional[torch.dtype] = None): params_dtype: Optional[torch.dtype] = None,
super().__init__(num_embeddings, embedding_dim, params_dtype) org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
set_weight_attrs(self.bias, { set_weight_attrs(self.bias, {
"parallel_dim": 0, "parallel_dim": 0,
......
...@@ -4,9 +4,8 @@ from typing import Type ...@@ -4,9 +4,8 @@ from typing import Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
...@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype): ...@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype) torch.set_default_dtype(old_dtype)
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures: for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch) model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None: if model_cls is not None:
...@@ -32,16 +37,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: ...@@ -32,16 +37,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
f"Supported architectures: {ModelRegistry.get_supported_archs()}") f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig, device_config: DeviceConfig,
model_class = _get_model_architecture(model_config.hf_config) **kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
model_class = _get_model_architecture(model_config)
# Get the (maybe quantized) linear method. # Get the (maybe quantized) linear method.
linear_method = None linear_method = None
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config.quantization, quant_config = get_quant_config(model_config)
model_config.model,
model_config.hf_config,
model_config.download_dir)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability(): if capability < quant_config.get_min_capability():
...@@ -61,8 +65,18 @@ def get_model(model_config: ModelConfig) -> nn.Module: ...@@ -61,8 +65,18 @@ def get_model(model_config: ModelConfig) -> nn.Module:
with _set_default_torch_dtype(model_config.dtype): with _set_default_torch_dtype(model_config.dtype):
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
with torch.device("cuda"): with torch.device(device_config.device):
model = model_class(model_config.hf_config, linear_method) if hasattr(model_class, "supported_lora_modules"):
model = model_class(model_config.hf_config, linear_method,
lora_config)
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy": if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
......
...@@ -4,39 +4,48 @@ from typing import List, Optional, Type ...@@ -4,39 +4,48 @@ from typing import List, Optional, Type
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip, is_neuron
logger = init_logger(__name__) logger = init_logger(__name__)
# Architecture -> (module, class). # Architecture -> (module, class).
_MODELS = { _MODELS = {
"AquilaModel": ("aquila", "AquilaForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2 "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case # transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"), "MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
} }
# Models not supported by ROCm. # Models not supported by ROCm.
...@@ -45,12 +54,17 @@ _ROCM_UNSUPPORTED_MODELS = [] ...@@ -45,12 +54,17 @@ _ROCM_UNSUPPORTED_MODELS = []
# Models partially supported by ROCm. # Models partially supported by ROCm.
# Architecture -> Reason. # Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = { _ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MistralForCausalLM": "MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention", "Sliding window attention is not yet supported in ROCm's flash attention",
"MixtralForCausalLM": "MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention", "Sliding window attention is not yet supported in ROCm's flash attention",
} }
# Models not supported by Neuron.
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
class ModelRegistry: class ModelRegistry:
...@@ -67,8 +81,15 @@ class ModelRegistry: ...@@ -67,8 +81,15 @@ class ModelRegistry:
logger.warning( logger.warning(
f"Model architecture {model_arch} is partially supported " f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
elif is_neuron():
if model_arch not in _NEURON_SUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"Neuron for now.")
module_name, model_cls_name = _MODELS[model_arch] module_name, model_cls_name = _MODELS[model_arch]
if is_neuron():
module_name = _NEURON_SUPPORTED_MODELS[model_arch]
module = importlib.import_module( module = importlib.import_module(
f"vllm.model_executor.models.{module_name}") f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None) return getattr(module, model_cls_name, None)
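# Example of resolving an architecture through the registry (illustrative
# usage; requires vllm to be importable, and the ROCm/Neuron checks above run
# before the importlib lookup):
#     from vllm.model_executor.models import ModelRegistry
#     llama_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")
#     # -> vllm.model_executor.models.llama.LlamaForCausalLM, per _MODELS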
......
...@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple ...@@ -23,6 +23,7 @@ from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -42,7 +43,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -42,7 +43,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator) hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -186,7 +186,7 @@ class BaiChuanAttention(nn.Module): ...@@ -186,7 +186,7 @@ class BaiChuanAttention(nn.Module):
class BaiChuanDecoderLayer(nn.Module): class BaiChuanDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: BaiChuanConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
...@@ -245,7 +245,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -245,7 +245,7 @@ class BaiChuanDecoderLayer(nn.Module):
class BaiChuanModel(nn.Module): class BaiChuanModel(nn.Module):
def __init__(self, def __init__(self,
config: BaiChuanConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
......
...@@ -28,6 +28,7 @@ from typing import Optional ...@@ -28,6 +28,7 @@ from typing import Optional
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader, from vllm.model_executor.weight_utils import (default_weight_loader,
...@@ -40,7 +41,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -40,7 +41,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
Based on the llama executor. Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention. The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overriden with a value The constant number of GQA heads in the decoder is overridden with a value
per layer. per layer.
Usually, in the HuggingFace implementation, instead of Usually, in the HuggingFace implementation, instead of
...@@ -56,10 +57,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -56,10 +57,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
self, self,
config: Optional[PretrainedConfig] = None, config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, linear_method=linear_method) super().__init__(config=config,
linear_method=linear_method,
lora_config=lora_config)
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
ReplicatedLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class DeepseekMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
self.experts = nn.ModuleList([
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
linear_method=None)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
reduce_results=False,
)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
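# Roughly, the two torch._utils helpers used above behave like this:
#   flat = _flatten_dense_tensors([t0, t1])            # one contiguous buffer
#   v0, v1 = _unflatten_dense_tensors(flat, [t0, t1])  # views into flat
# Re-binding each expert's weight .data to those views makes every per-expert
# weight an alias into the packed self.w1 / self.w2 buffers, so fused_moe can
# consume a single [n_routed_experts, ...] tensor without extra copies.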
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.config.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(batch_size, sequence_length,
hidden_dim)
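# fused_moe above fuses routing (softmax + top-k + optional renormalization)
# with the expert GEMMs. As a point of reference only, an unfused equivalent
# of that routing would look roughly like the sketch below (our names, not
# the vLLM kernel; the per-token loop is kept for clarity, not speed):
import torch

def _naive_top_k_moe(hidden_states, experts, router_logits, top_k,
                     renormalize=True):
    # hidden_states: [num_tokens, hidden_dim]; experts: list of callables
    routing_weights = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = routing_weights.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for token_idx in range(hidden_states.shape[0]):
        for weight, expert_idx in zip(topk_weights[token_idx],
                                      topk_ids[token_idx]):
            out[token_idx] += weight * experts[int(expert_idx)](
                hidden_states[token_idx])
    return out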
class DeepseekAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
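# Worked example of the head partitioning above (numbers are illustrative):
# with num_heads=32, num_kv_heads=4 and tp_size=8, each rank holds
# 32 // 8 = 4 query heads and max(1, 4 // 8) = 1 KV head, i.e. each KV head
# is replicated across 8 // 4 = 2 ranks.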
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = DeepseekAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
if (config.n_routed_experts is not None and \
layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
linear_method=linear_method)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = DeepseekModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
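# Walk-through of the remapping loop above for two typical checkpoint keys
# (key names follow the HF Llama-style layout and are illustrative):
#   "model.layers.0.self_attn.q_proj.weight"
#       -> param "model.layers.0.self_attn.qkv_proj.weight", shard_id "q"
#   "model.layers.0.mlp.gate_proj.weight"
#       -> param "model.layers.0.mlp.gate_up_proj.weight", shard_id 0
# Keys that match no entry fall through to the else branch and are loaded
# with their param's own weight_loader (or default_weight_loader).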