"vscode:/vscode.git/clone" did not exist on "f95da13c3da5e5ff74e5f5f1da109c5e1cae0886"
Commit 7c4f76e3 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.0

parents 2da0dd3e 51c31bc1
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
{ {
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1": {
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "BLOCK_SIZE_M": 16,
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "BLOCK_SIZE_N": 64,
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4}, "BLOCK_SIZE_K": 128,
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "GROUP_SIZE_M": 64,
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "num_warps": 4,
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "num_stages": 4
"80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, },
"96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "2": {
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "BLOCK_SIZE_M": 16,
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}, "BLOCK_SIZE_N": 128,
"200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, "BLOCK_SIZE_K": 128,
"208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4}, "GROUP_SIZE_M": 32,
"216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}, "num_warps": 8,
"224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4}, "num_stages": 4
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4}, },
"512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, "4": {
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, "BLOCK_SIZE_M": 16,
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, "BLOCK_SIZE_N": 32,
"2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, "BLOCK_SIZE_K": 256,
"3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}, "GROUP_SIZE_M": 16,
"4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4} "num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
} }
...@@ -30,9 +30,10 @@ def fused_moe_kernel( ...@@ -30,9 +30,10 @@ def fused_moe_kernel(
K, K,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1 # The stride variables represent how much to increase the ptr by when
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # moving by 1 element in a particular dimension. E.g. `stride_am` is
# by to get the element one row down (A has M rows). # how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
...@@ -50,17 +51,30 @@ def fused_moe_kernel( ...@@ -50,17 +51,30 @@ def fused_moe_kernel(
compute_type: tl.constexpr, compute_type: tl.constexpr,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters: Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - A: The input tensor representing tokens with shape (*, K), where '*' can
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. be any shape representing batches and K is the feature dimension of
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, each token.
and N is the output feature dimension. - B: The stacked MOE weight tensor with shape (E, N, K), where E is
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. the number of experts, K is the input feature dimension, and N is
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. the output feature dimension.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` - C: The output cache tensor with shape (M, topk, N), where M is the
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
""" """
# ----------------------------------------------------------- # -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
...@@ -105,7 +119,8 @@ def fused_moe_kernel( ...@@ -105,7 +119,8 @@ def fused_moe_kernel(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension. # Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs, a = tl.load(a_ptrs,
mask=token_mask[:, None] & mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K), (offs_k[None, :] < K - k * BLOCK_SIZE_K),
...@@ -139,30 +154,41 @@ def moe_align_block_size( ...@@ -139,30 +154,41 @@ def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters: Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication. - block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts. - num_experts: The total number of experts.
Returns: Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block. - expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. - num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. This function pads the number of tokens that each expert needs to process
Padding ensures that during block matrix multiplication, the dimensions align correctly. so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example: Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert. - As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block. - Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. - After sorting by expert index, we obtain token_ids
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
""" """
sorted_ids = torch.empty( sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1), ), (topk_ids.numel() + num_experts * (block_size - 1), ),
...@@ -219,23 +245,28 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -219,23 +245,28 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
) )
def get_config_file_name(E: int, N: int) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_")
return f"E={E},N={N},device_name={device_name}.json"
@functools.lru_cache @functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of batch sizes The return value will be a dictionary that maps an irregular grid of
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch batch sizes to configurations of the fused_moe kernel. To evaluate the
size bs, the closest batch size in the grid should be picked and the associated kernel on a given batch size bs, the closest batch size in the grid should
configuration chosen to invoke the kernel. be picked and the associated configuration chosen to invoke the kernel.
""" """
# First look up if an optimized configuration is available in the configs directory # First look up if an optimized configuration is available in the configs
device_name = torch.cuda.get_device_name().replace(" ", "_") # directory
json_file_name = get_config_file_name(E, N)
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
f"E={E},N={N},device_name={device_name}.json")
if os.path.exists(config_file_path): if os.path.exists(config_file_path):
with open(config_file_path) as f: with open(config_file_path) as f:
logger.info( logger.info(
...@@ -243,7 +274,8 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: ...@@ -243,7 +274,8 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# If a configuration has been found, return it # If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()} return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default configuration # If no optimized configuration is available, we will use the default
# configuration
return None return None
...@@ -258,18 +290,22 @@ def fused_moe( ...@@ -258,18 +290,22 @@ def fused_moe(
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters: Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation (before softmax). - gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False. - inplace (bool): If True, perform the operation in-place.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
""" """
...@@ -325,7 +361,8 @@ def fused_moe( ...@@ -325,7 +361,8 @@ def fused_moe(
configs = get_moe_configs(E, w2.shape[2]) configs = get_moe_configs(E, w2.shape[2])
if configs: if configs:
# If an optimal configuration map has been found, look up the optimal config # If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Else use the default config # Else use the default config
......
...@@ -5,14 +5,14 @@ import torch ...@@ -5,14 +5,14 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.parallel_utils.utils import ( from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim) divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -73,7 +73,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -73,7 +73,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = weights["weight"] weight = weights["weight"]
if self.separate_bias_add: if self.separate_bias_add:
if bias: if bias is not None:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
return F.linear(x, weight) return F.linear(x, weight)
return F.linear(x, weight, bias) return F.linear(x, weight, bias)
...@@ -285,7 +285,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -285,7 +285,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -307,7 +308,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -307,7 +308,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -413,7 +415,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -413,7 +415,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
...@@ -442,7 +445,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -442,7 +445,8 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling. # If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
......
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
class LogitsProcessor(nn.Module):
    """Turn hidden states into final logits for sampling.

    The layer performs three steps:
    1. Compute (and gather) logits from the model's hidden states.
    2. Multiply the logits by ``scale``.
    3. Run any logits processors found in the sampling metadata.
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: Optional[float] = 1.0,
                 logits_as_input: bool = False) -> None:
        """
        Args:
            vocab_size: Vocabulary size stored on the layer.
            org_vocab_size: Original vocabulary size (without LoRA). Falls
                back to ``vocab_size`` when omitted (or otherwise falsy).
            scale: A scaling factor to apply to the logits.
            logits_as_input: When True, ``forward`` treats its
                ``hidden_states`` argument as ready-made logits.
        """
        super().__init__()
        self.vocab_size = vocab_size
        # Width the gathered logits are truncated to in `_get_logits`.
        self.org_vocab_size = org_vocab_size or vocab_size
        self.scale = scale
        # Default is False: the input is hidden states, not logits.
        self.logits_as_input = logits_as_input

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute, scale and post-process the logits.

        Returns None when `_get_logits` yields no logits. Scaling is done
        in place, so with ``logits_as_input=True`` the tensor passed as
        ``hidden_states`` is itself modified.
        """
        if self.logits_as_input:
            logits = hidden_states
        else:
            # Keep only the rows selected by the sampling metadata, then
            # project them to logits for the next tokens.
            selected = _prune_hidden_states(hidden_states, sampling_metadata)
            logits = self._get_logits(selected, embedding, embedding_bias)

        if logits is None:
            # NOTE(review): presumably a rank that does not receive the
            # gathered logits -- confirm against the gather helper.
            return None

        logits *= self.scale
        # Apply logits processors (if any).
        return _apply_logits_processors(logits, sampling_metadata)

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Project hidden states onto the embedding matrix and gather.

        The optional bias is added before the gather. The gathered result is
        truncated to the first ``org_vocab_size`` columns to drop any vocab
        padding. Returns None if the gather returns None.
        """
        projected = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            projected += embedding_bias

        gathered = tensor_model_parallel_gather(projected)
        if gathered is None:
            return None
        # Remove paddings in vocab (if any).
        return gathered[:, :self.org_vocab_size]
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_row_idx += 1
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
assert logits_row_idx == logits.shape[0]
return logits
from typing import Optional, Union
import torch
import triton
import triton.language as tl
def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
    ]

    Args:
        *size: Output shape as 1 to 3 separate integers.
        seeds: 1D tensor with one seed per output row (a 1D output counts
            as a single row).
        out: Optional pre-allocated output; its shape must equal ``size``.
        dtype: Passed to ``torch.empty`` when ``out`` is not given.
        device: Passed to ``torch.empty`` when ``out`` is not given.
        pin_memory: Passed to ``torch.empty`` when ``out`` is not given.

    Returns:
        The tensor filled with random values (``out`` itself if provided).
        An empty output is returned untouched, without launching the kernel.

    Raises:
        ValueError: If no size or more than 3 dimensions are given, if
            ``out`` does not match ``size``, or if ``seeds`` is not a 1D
            tensor with one element per row.
    """
    n_dims = len(size)
    if n_dims < 1:
        # Without this guard a missing size falls into the 1D branch below
        # and fails with an unrelated indexing/argument error.
        raise ValueError("seeded_uniform requires at least one dimension")
    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    # Normalize every supported rank to (rows, 3d, cols). Missing leading
    # dimensions have extent 1, so their stride value is never multiplied
    # by a non-zero index.
    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    if out.numel() == 0:
        # Nothing to generate. Skipping the launch also avoids computing a
        # block configuration from a zero-sized column count.
        return out

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    # One program per (row, 3d index) pair.
    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out
@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for
    that row.

    Each program handles one (row, 3d index) pair and writes up to
    ``n_slices * block_size`` consecutive columns; columns are addressed
    with unit stride from the start of that row.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor (not read in the
            kernel body).
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D (not read in the kernel body).
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
        block_size: The number of columns covered by each philox output.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Program coordinates: axis 0 is the row, axis 1 the 3d index.
    row = tl.program_id(axis=0)
    plane = tl.program_id(axis=1)

    lane_offsets = tl.arange(0, block_size)

    # Per-row seed; for 3d indices above 0 it is XOR-ed with the index.
    seed = tl.load(seed_ptr + row * seed_row_stride)
    if plane > 0:
        seed ^= plane

    # One philox call yields four blocks of random numbers in [0, 1).
    rand0, rand1, rand2, rand3 = tl.rand4x(seed, lane_offsets)

    row_base_ptr = out_ptr + row * out_row_stride + plane * out_3d_stride

    # Slice k covers columns [k * block_size, (k + 1) * block_size); the
    # mask drops columns past the end of the row.
    tl.store(row_base_ptr + lane_offsets, rand0, mask=lane_offsets < n_cols)
    if n_slices > 1:
        slice1_offsets = lane_offsets + block_size
        tl.store(row_base_ptr + slice1_offsets,
                 rand1,
                 mask=slice1_offsets < n_cols)
    if n_slices > 2:
        slice2_offsets = lane_offsets + block_size * 2
        tl.store(row_base_ptr + slice2_offsets,
                 rand2,
                 mask=slice2_offsets < n_cols)
    if n_slices > 3:
        slice3_offsets = lane_offsets + block_size * 3
        tl.store(row_base_ptr + slice3_offsets,
                 rand3,
                 mask=slice3_offsets < n_cols)
import math
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
_EPS = 1e-6
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Return how many column chunks are needed for Triton sampling.

    A single kernel launch can cover at most MAX_TRITON_N_COLS columns, so a
    wider tensor has to be split and the kernel invoked once per chunk.
    """
    columns_per_limit = n_cols / MAX_TRITON_N_COLS
    return math.ceil(columns_per_limit)
def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    *,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise).

    The vocab dimension of probs/logprobs is cut into n_splits chunks, every
    chunk is sampled independently, and for each output position the chunk
    whose winning (noise-modified) probability is largest provides the final
    token.

    Args:
        probs: Probabilities to sample from; split along dim 1.
        seeds: 2D tensor whose first dimension equals n_splits; row i holds
            the seeds used for chunk i.
        n_splits: Number of chunks to split the vocab dimension into.
        sampled_tokens_size: Shape of the sampled-token output
            (also used for the modified-probs buffers).
        sampled_logprobs_size: Shape of the per-chunk logprob buffers.
        sample_indices: Indices of the rows of probs to sample from.
        logprobs: Log-probabilities matching probs. Required when
            save_logprobs is True; otherwise probs is used as a placeholder.
        modify_greedy_probs: If True, probs is overwritten in-place with a
            one-hot encoding of the sampled tokens after the reduction.
        save_logprobs: If True, the logprobs of the sampled tokens are
            gathered and returned.

    Returns:
        A tuple (sampled_tokens, sampled_logprobs, sampled_modified_probs)
        where sampled_logprobs is None unless save_logprobs is True.

    Raises:
        ValueError: If save_logprobs is True but logprobs is None.
    """
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    if logprobs is None:
        if save_logprobs:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        # The kernel only reads logprobs when save_logprobs is set, so probs
        # can stand in as a same-shaped placeholder (as sample() does).
        logprobs = probs

    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]

    n_samples = sample_indices.shape[0]
    # Number of vocab columns that precede the current chunk. tensor_split
    # may return chunks of different widths (the leading ones get the extra
    # column), so the offset has to be accumulated instead of being derived
    # from the width of a single chunk.
    col_offset = 0
    for i in range(n_splits):
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            # Greedy-prob modification is deferred until the winning chunk
            # is known; the modified probs are always needed for the
            # reduction below.
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if col_offset > 0:
            # Convert chunk-local token ids into ids in the full vocab.
            sampled_tokens_tmp[i].add_(col_offset)
        col_offset += n_cols

    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits: for every output position keep
    # the chunk with the largest modified probability.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        # NOTE(review): this zeroes every row of probs and scatters by
        # position in sampled_tokens (not through sample_indices), for
        # random and greedy rows alike - confirm this is intended.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Draw tokens from probs using per-sequence seeds.

    Sampling can be restricted to a subset of the rows of probs through
    sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            NOTE(review): when the vocab has to be split, the multi-split
            helper requires a 2D tensor whose first dimension is the number
            of splits; otherwise the tensor is flattened before being used
            for noise generation. Confirm the exact layout with the caller.
        max_best_of: Number of samples to generate per sequence. The
            additional per-sample seeds are derived by seeded_uniform.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the tokens. Only read when
            save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to return the log-probabilities of the
            sampled tokens.
        _save_modified_probs: Whether to return the modified probabilities
            (including gumbel noise) of the sampled tokens.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if _save_modified_probs else None

    Raises:
        ValueError: If save_logprobs is True but logprobs is not given.
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    out_shape = (sample_indices.size(0), max_best_of)
    # Outputs that were not requested still have to be handed to the kernel;
    # an empty (0, 0) tensor serves as the placeholder.
    placeholder_shape = (0, 0)

    if save_logprobs and logprobs is None:
        raise ValueError(
            "logprobs tensor must be provided if save_logprobs is True")
    if not save_logprobs:
        # Never read by the kernel in this mode; only a valid tensor is
        # needed for the launch.
        logprobs = probs
    logprobs_shape = out_shape if save_logprobs else placeholder_shape
    modified_probs_shape = (out_shape
                            if _save_modified_probs else placeholder_shape)

    # If the number of columns in probs is too large for Triton to handle,
    # the tensor is split, each split is sampled separately, and the results
    # are combined with an argmax+gather.
    vocab_size = probs.shape[1]
    n_splits = get_num_triton_sampler_splits(vocab_size)

    if n_splits > 1:
        tokens, token_logprobs, modified_probs = _multi_split_sample(
            probs,
            seeds,
            n_splits,
            out_shape,
            logprobs_shape,
            sample_indices,
            logprobs=logprobs,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs)
    else:
        device, dtype = probs.device, probs.dtype
        tokens = torch.empty(out_shape, dtype=torch.long, device=device)
        token_logprobs = torch.empty(logprobs_shape,
                                     dtype=dtype,
                                     device=device)
        modified_probs = torch.empty(modified_probs_shape,
                                     dtype=dtype,
                                     device=device)
        uniform_noise = seeded_uniform(sample_indices.shape[0],
                                       max_best_of,
                                       vocab_size,
                                       seeds=seeds.flatten(),
                                       device=device,
                                       dtype=dtype)
        _sample(
            probs,
            logprobs,
            sample_indices,
            tokens,
            token_logprobs,
            modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )

    return (tokens, token_logprobs if save_logprobs else None,
            modified_probs if _save_modified_probs else None)
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sample_indices: torch.Tensor,
    output_samples: torch.Tensor,
    output_logprobs: torch.Tensor,
    output_modified_probs: torch.Tensor,
    seeds: torch.Tensor,
    uniform_noise: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = True,
    save_modified_probs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample tokens from probs by launching the Triton sampling kernel.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True). The kernel addresses it with the row
            stride of probs.
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [n, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).

    Returns:
        The three output tensors that were passed in, in the order
        (output_samples, output_logprobs, output_modified_probs). Outputs
        that were not requested are returned untouched.
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than or equal to
    # the number of columns in probs, so one block covers a whole row.
    block_size = triton.next_power_of_2(n_cols)
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4

    # Enqueue kernel. The launch grid is 2D: one kernel instance per
    # (sampled row, best-of index) pair.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        # The same row stride is used by the kernel for all three outputs.
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs
@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Turn uniform samples into exponential samples (inversion method)."""
    # tl.rand returns values in [0, 1); raise anything below _EPS up to
    # _EPS so that log() is never evaluated at 0.
    floor = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    clamped_noise = tl.maximum(uniform_noise, floor)
    # Inversion method: the negative log of a uniform sample is an
    # exponential sample.
    return -tl.log(clamped_noise)
@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    """Pick one token for a single (sample, best-of index) pair.

    Program id 0 selects the sample and program id 1 the best-of index.
    The row of probs named by sample_indices is loaded, divided by
    exponential noise when the sample's seed is non-zero, and the index of
    the maximum becomes the sampled token. A seed of 0 means the raw probs
    are used, i.e. the argmax is taken.

    All three outputs are addressed as
    `sample_idx * output_row_stride + best_idx`, and logprobs is addressed
    with probs_row_stride, so these tensors are expected to share those
    layouts.
    """
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    # Seeds are read with an implicit stride of 1 element per sample.
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be > than n_cols
    # Padding lanes are filled with -inf so they cannot win the max below.
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        # Padding lanes get a harmless 0.5 so the division stays finite.
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # clamp sampled token to n_cols - 1
    # this should not be necessary, but we do it
    # just in case
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        # sampled_value is the winning entry of the (possibly noise-divided)
        # row, taken before any one-hot rewrite of probs above.
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load only the logprob of the sampled token. The logprobs tensor
        # is indexed with the row stride of probs.
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
from typing import Type from typing import Type
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
_QUANTIZATION_CONFIG_REGISTRY = { _QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig, "awq": AWQConfig,
......
...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter ...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
...@@ -50,7 +51,8 @@ class AWQConfig(QuantizationConfig): ...@@ -50,7 +51,8 @@ class AWQConfig(QuantizationConfig):
def get_config_filenames() -> List[str]: def get_config_filenames() -> List[str]:
return [ return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
"quantize_config.json",
] ]
@classmethod @classmethod
......
import enum import enum
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional
from fractions import Fraction from fractions import Fraction
from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -31,8 +31,8 @@ class GPTQConfig(QuantizationConfig): ...@@ -31,8 +31,8 @@ class GPTQConfig(QuantizationConfig):
self.pack_factor = Fraction(32, self.weight_bits) self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]: if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError( raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is supported for " "Currently, only 2/3/4/8-bit weight quantization is "
f"GPTQ, but got {self.weight_bits} bits.") f"supported for GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, " return (f"GPTQConfig(weight_bits={self.weight_bits}, "
...@@ -101,7 +101,8 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -101,7 +101,8 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError( raise ValueError(
"The output size is not aligned with the quantized " "The output size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
...@@ -114,7 +115,8 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -114,7 +115,8 @@ class GPTQLinearMethod(LinearMethodBase):
exllama_state = ExllamaState.UNINITIALIZED exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None scale_and_zero_input_dim = None
if input_size != input_size_per_partition and self.quant_config.group_size != -1: if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row parallel layer # For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act: if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED exllama_state = ExllamaState.UNUSED
......
...@@ -4,8 +4,10 @@ import torch ...@@ -4,8 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class MarlinConfig(QuantizationConfig): class MarlinConfig(QuantizationConfig):
...@@ -22,8 +24,9 @@ class MarlinConfig(QuantizationConfig): ...@@ -22,8 +24,9 @@ class MarlinConfig(QuantizationConfig):
self.group_size = group_size self.group_size = group_size
if self.group_size != 128 and self.group_size != -1: if self.group_size != 128 and self.group_size != -1:
raise ValueError( raise ValueError(
"Currently, only group size 128 and -1 (channelwise) is supported for " "Currently, only group size 128 and -1 (channelwise) "
f"Marlin, but got group_size of {self.group_size}") "is supported for Marlin, but got group_size of "
f"{self.group_size}")
# 4 Bits packed into 32 bit datatype. # 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4 self.pack_factor = 32 // 4
...@@ -37,14 +40,15 @@ class MarlinConfig(QuantizationConfig): ...@@ -37,14 +40,15 @@ class MarlinConfig(QuantizationConfig):
# Min in_features dim # Min in_features dim
self.min_k_threads = 128 self.min_k_threads = 128
# Max parallel problems to solve at once (improves large batch performance) # Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16 self.max_parallel = 16
# Permutation length used by the marlin kernels. # Permutation length used by the marlin kernels.
self.perm_len = 1024 self.perm_len = 1024
def __repr__(self) -> str: def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size}" return f"MarlinConfig(group_size={self.group_size})"
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -102,22 +106,26 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -102,22 +106,26 @@ class MarlinLinearMethod(LinearMethodBase):
# Validate output_size_per_partition # Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0: if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError( raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}." f"Weight output_size_per_partition = "
) f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0: if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError( raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}." f"Weight output_size_per_partition = "
) f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition # Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0: if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError( raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}." f"Weight input_size_per_partition = "
) f"{input_size_per_partition} is not divisible by "
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0: f"min_k_threads = {self.quant_config.min_k_threads}.")
raise ValueError( if (self.quant_config.group_size != -1 and
f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}." input_size_per_partition % self.quant_config.group_size != 0):
) raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard # Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // ( num_tiles_per_perm = self.quant_config.perm_len // (
...@@ -149,7 +157,9 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -149,7 +157,9 @@ class MarlinLinearMethod(LinearMethodBase):
) )
# Determine if channelwise or not # Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
scales = Parameter( scales = Parameter(
torch.empty( torch.empty(
......
...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter ...@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import is_hip from vllm.utils import is_hip
......
from typing import Tuple, Optional
from functools import cached_property from functools import cached_property
from typing import Optional, Tuple
import torch import torch
import torch.nn as nn
import torch.jit import torch.jit
import torch.nn as nn
class RejectionSampler(nn.Module): class RejectionSampler(nn.Module):
...@@ -21,8 +21,6 @@ class RejectionSampler(nn.Module): ...@@ -21,8 +21,6 @@ class RejectionSampler(nn.Module):
nontrivial latency. nontrivial latency.
""" """
super().__init__() super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._strict_mode = strict_mode self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are # NOTE: A "bonus token" is accepted iff all proposal tokens are
...@@ -44,6 +42,14 @@ class RejectionSampler(nn.Module): ...@@ -44,6 +42,14 @@ class RejectionSampler(nn.Module):
dtype=torch.long, dtype=torch.long,
device=device) device=device)
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
def forward( def forward(
self, self,
target_probs: torch.Tensor, target_probs: torch.Tensor,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Rotary Positional Embeddings.""" """Rotary Positional Embeddings."""
import math import math
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -96,6 +96,7 @@ class RotaryEmbedding(nn.Module): ...@@ -96,6 +96,7 @@ class RotaryEmbedding(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
...@@ -107,7 +108,9 @@ class RotaryEmbedding(nn.Module): ...@@ -107,7 +108,9 @@ class RotaryEmbedding(nn.Module):
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
cos_sin = self.cos_sin_cache[positions] self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style: if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the # NOTE(woosuk): Here we assume that the positions tensor has the
...@@ -137,11 +140,19 @@ class RotaryEmbedding(nn.Module): ...@@ -137,11 +140,19 @@ class RotaryEmbedding(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# ops.rotary_embedding() is an in-place operation that self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
# updates the query and key tensors. # ops.rotary_embedding()/batched_rotary_embedding()
ops.rotary_embedding(positions, query, key, self.head_size, # are in-place operations that update the query and key tensors.
self.cos_sin_cache, self.is_neox_style) if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
...@@ -158,27 +169,32 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -158,27 +169,32 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factors: Union[List[float], float],
) -> None: ) -> None:
self.scaling_factor = scaling_factor if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors = scaling_factors
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style) is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base) inv_freq = self._compute_inv_freq(self.base)
# NOTE(woosuk): self.max_position_embeddings is the original cache_list = []
# maximum length before applying the rope scaling. for scaling_factor in self.scaling_factors:
# Thus, the maximum length after applying the rope scaling is # NOTE(woosuk): self.max_position_embeddings is the original
# self.max_position_embeddings * self.scaling_factor. # maximum length before applying the rope scaling.
max_len = self.max_position_embeddings * self.scaling_factor # Thus, the maximum length after applying the rope scaling is
t = torch.arange(max_len, dtype=torch.float) # self.max_position_embeddings * self.scaling_factor.
t = t / self.scaling_factor max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) t = t / scaling_factor
cos = freqs.cos()
sin = freqs.sin() freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((cos, sin), dim=-1) cos = freqs.cos()
return cache sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
cache_list.append(cache)
return torch.cat(cache_list, dim=0)
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
......
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.layers.ops.sample import sample as sample_triton
tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import (SamplingMetadata,
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SequenceData, SequenceGroupOutput, SequenceOutput) SamplerOutput, SequenceData, SequenceGroupOutput,
from vllm.utils import is_neuron SequenceOutput)
class Sampler(nn.Module): class Sampler(nn.Module):
...@@ -28,57 +29,17 @@ class Sampler(nn.Module): ...@@ -28,57 +29,17 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.). parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
""" """
def __init__(self,
             vocab_size: int,
             org_vocab_size: Optional[int] = None) -> None:
    """Set up the sampler's vocabulary bookkeeping.

    Args:
        vocab_size: Size of the vocabulary the logits are produced over.
        org_vocab_size: Original vocabulary size (without LoRA). Falls
            back to ``vocab_size`` when omitted (or otherwise falsy).
    """
    super().__init__()
    self.vocab_size = vocab_size
    # Transformers-neuronx generate outputs as logits directly, so on
    # Neuron the incoming hidden states are already the logits.
    self.logits_as_hidden_states = is_neuron()
    # Original vocabulary size (without LoRA).
    if org_vocab_size:
        self.org_vocab_size = org_vocab_size
    else:
        self.org_vocab_size = vocab_size
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Project hidden states onto the vocabulary and gather the result.

    Args:
        hidden_states: Hidden states to turn into next-token logits.
        embedding: Embedding weight; its transpose is the projection.
        embedding_bias: Optional bias added to the projected logits.

    Returns:
        The gathered logits truncated to ``self.org_vocab_size`` columns,
        or ``None`` when the gather yields ``None``.
    """
    # Get the logits for the next tokens.
    projected = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        projected += embedding_bias
    gathered = tensor_model_parallel_gather(projected)
    if gathered is None:
        return gathered
    # Remove paddings in vocab (if any).
    return gathered[:, :self.org_vocab_size]
def forward( def forward(
self, self,
embedding: torch.Tensor, logits: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling.
if self.logits_as_hidden_states:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
# TODO(zhuohan): Change the get_logits part to a separate stage.
if not sampling_metadata.perform_sampling:
return None
assert logits is not None assert logits is not None
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Apply logits processors (if any). # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
logits = _apply_logits_processors(logits, sampling_metadata) # have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking. # Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k, (sampling_tensors, do_penalties, do_top_p_top_k,
...@@ -112,7 +73,8 @@ class Sampler(nn.Module): ...@@ -112,7 +73,8 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata) sample_results = _sample(probs, logprobs, sampling_metadata,
sampling_tensors)
# Get the logprobs query results. # Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs( prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results) logprobs, sampling_metadata, sample_results)
...@@ -120,15 +82,6 @@ class Sampler(nn.Module): ...@@ -120,15 +82,6 @@ class Sampler(nn.Module):
prompt_logprobs, sample_logprobs) prompt_logprobs, sample_logprobs)
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Keep only the hidden-state rows selected for sampling.

    The input is flattened to 2D (last dimension preserved) and the rows
    listed in ``sampling_metadata.selected_token_indices`` are returned.
    """
    hidden_size = hidden_states.shape[-1]
    flattened = hidden_states.view(-1, hidden_size)
    selected_rows = sampling_metadata.selected_token_indices
    return flattened.index_select(0, selected_rows)
def _get_bin_counts_and_mask( def _get_bin_counts_and_mask(
tokens: torch.Tensor, tokens: torch.Tensor,
vocab_size: int, vocab_size: int,
...@@ -146,27 +99,39 @@ def _get_bin_counts_and_mask( ...@@ -146,27 +99,39 @@ def _get_bin_counts_and_mask(
return bin_counts, mask return bin_counts, mask
def _apply_logits_processors( def _apply_min_tokens_penalty(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
logits_row_idx = 0 # list of indices in logits that will be set to -inf
found_logits_processors = False logits_to_penalize = []
start_idx = 0
for seq_ids, sampling_params in sampling_metadata.seq_groups: for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors min_tokens = sampling_params.min_tokens
if logits_processors: if min_tokens > 0:
found_logits_processors = True seqs_to_penalize = []
for seq_id in seq_ids: for i, seq_id in enumerate(seq_ids):
logits_row = logits[logits_row_idx] seq_data = sampling_metadata.seq_data[seq_id]
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids if len(seq_data.output_token_ids) < min_tokens:
for logits_processor in logits_processors: seqs_to_penalize.append(i)
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row if seqs_to_penalize:
logits_row_idx += 1 # convert to the index into logits
else: seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
logits_row_idx += len(seq_ids) # use set() to remove any duplicates
if found_logits_processors: token_ids_to_penalize = set(sampling_params.stop_token_ids +
assert logits_row_idx == logits.shape[0] [sampling_params.eos_token_id])
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
start_idx += len(seq_ids)
if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
return logits return logits
...@@ -373,7 +338,7 @@ def _multinomial( ...@@ -373,7 +338,7 @@ def _multinomial(
return probs.div_(q).argmax(dim=1).view(-1, num_samples) return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _sample( def _sample_with_torch(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
...@@ -392,7 +357,7 @@ def _sample( ...@@ -392,7 +357,7 @@ def _sample(
# Counterintuitively, having two loops here is actually faster. # Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync. # The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType: for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type] sample_indices = categorized_sample_indices[sampling_type][:, 0]
num_tokens = len(sample_indices) num_tokens = len(sample_indices)
if num_tokens == 0: if num_tokens == 0:
continue continue
...@@ -405,17 +370,19 @@ def _sample( ...@@ -405,17 +370,19 @@ def _sample(
greedy_samples = torch.argmax(logprobs[sample_indices.long()], greedy_samples = torch.argmax(logprobs[sample_indices.long()],
dim=-1) dim=-1)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of = 1 max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts): for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt: if is_prompt:
_, sampling_params = seq_group _, sampling_params = seq_group
max_best_of = max(max_best_of, sampling_params.best_of) max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else { seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups, "seq_groups": seq_groups,
"generators": sampling_metadata.generators, "generators": sampling_metadata.generators,
} }
multinomial_samples[sampling_type] = _multinomial( multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices.long()], max_best_of, **seeded_args) probs[sample_indices.long()], max_best_of_in_batch,
**seeded_args)
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices] beam_search_logprobs = logprobs[sample_indices]
else: else:
...@@ -446,6 +413,118 @@ def _sample( ...@@ -446,6 +413,118 @@ def _sample(
return sample_results return sample_results
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens through ``sample_triton`` and regroup the results.

    Sequence groups are bucketed by sampling type, a single
    ``sample_triton`` call draws tokens for the whole batch, and the
    per-type helpers (`_greedy_sample`, `_random_sample`,
    `_beam_search_sample`) turn the drawn tokens into per-group results.

    Returns:
        One sample result per sequence group, in the order of
        ``sampling_metadata.seq_groups``.

    Raises:
        ValueError: If a sampling type with pending samples is not one of
            GREEDY, RANDOM, RANDOM_SEED or BEAM.
    """
    all_seq_groups = sampling_metadata.seq_groups
    indices_by_type = sampling_metadata.categorized_sample_indices

    # Bucket the position of every sequence group by its sampling type.
    group_ids_by_type = {kind: [] for kind in SamplingType}
    for group_idx, (_, params) in enumerate(all_seq_groups):
        group_ids_by_type[params.sampling_type].append(group_idx)

    results_by_group: Dict[int, Tuple[List[int], List[int]]] = {}
    pending = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for kind in SamplingType:
        type_indices = indices_by_type[kind]
        sample_indices = type_indices[:, 0]
        sampled_token_indices = type_indices[:, 1]
        if len(sample_indices) == 0:
            continue

        group_ids = group_ids_by_type[kind]
        groups = [all_seq_groups[g] for g in group_ids]
        prompt_flags = [g < sampling_metadata.num_prompts for g in group_ids]
        pending[kind] = (group_ids, groups, prompt_flags, sample_indices,
                         sampled_token_indices)

        if kind == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        elif kind in (SamplingType.GREEDY, SamplingType.RANDOM,
                      SamplingType.RANDOM_SEED):
            # Only prompt groups contribute their best_of to the batch max.
            prompt_best_ofs = [
                params.best_of
                for (_, params), is_prompt in zip(groups, prompt_flags)
                if is_prompt
            ]
            max_best_of_in_batch = max(
                [max_best_of_in_batch, *prompt_best_ofs])
        else:
            raise ValueError(f"Unsupported sampling type: {kind}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.
    # `pending` was filled while iterating SamplingType, so it preserves
    # that iteration order.
    for kind, entry in pending.items():
        (group_ids, groups, prompt_flags, _sample_indices,
         sampled_token_indices) = entry
        if kind == SamplingType.GREEDY:
            picked = _greedy_sample(
                groups, sampled_tokens[sampled_token_indices][:, 0])
        elif kind in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            picked = _random_sample(groups, prompt_flags,
                                    sampled_tokens[sampled_token_indices])
        elif kind == SamplingType.BEAM:
            picked = _beam_search_sample(groups, prompt_flags,
                                         sampling_metadata.seq_data,
                                         beam_search_logprobs)
        results_by_group.update(zip(group_ids, picked))

    # Restore the original sequence-group ordering.
    return [results_by_group[i] for i in range(len(all_seq_groups))]
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens, currently always through the torch-based path.

    ``sampling_tensors`` is accepted for the Triton-based path but is not
    used while that path stays disabled.
    """
    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)
    sample_results = _sample_with_torch(probs, logprobs, sampling_metadata)
    return sample_results
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    Calculate the ranks of the chosen tokens in a logprob tensor.

    The rank of a chosen token is one plus the number of entries in its
    row that are strictly greater than the chosen entry, so the largest
    entry of a row has rank 1.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): Integer tensor with one chosen token index
                        per row of `x` (its dtype is reused for the row
                        index tensor).

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    row_ids = torch.arange(0, len(x), device=x.device, dtype=indices.dtype)
    chosen = x[row_ids, indices]
    # Summing the boolean mask already accumulates in int64, so no
    # intermediate (N, M) integer copy of the mask is needed.
    return (x > chosen[:, None]).sum(1).add_(1)
def _get_logprobs( def _get_logprobs(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
...@@ -455,7 +534,8 @@ def _get_logprobs( ...@@ -455,7 +534,8 @@ def _get_logprobs(
# Prepare query indices # Prepare query indices
batched_logprobs_query_seq_indices: List[int] = [] batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = [] batched_logprobs_query_token_indices: List[int] = []
largest_num_logprobs = 0 # at least get one logprob for each token
largest_num_logprobs = 1
sample_idx = 0 sample_idx = 0
for i, (seq_group, sample_result) in enumerate( for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)): zip(sampling_metadata.seq_groups, sample_results)):
...@@ -483,12 +563,21 @@ def _get_logprobs( ...@@ -483,12 +563,21 @@ def _get_logprobs(
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0) assert sample_idx == logprobs.size(0)
batched_logprobs_query_seq_indices_gpu = torch.tensor(
batched_logprobs_query_seq_indices, device=logprobs.device)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device)
# Batched query for logprobs of selected token # Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[ batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices, batched_logprobs_query_seq_indices_gpu,
batched_logprobs_query_token_indices batched_logprobs_query_token_indices_gpu
]] ]]
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices_gpu],
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens # Batched query for logprobs of topk tokens
if largest_num_logprobs > 0: if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs, top_token_ids = torch.topk(logprobs,
...@@ -500,6 +589,7 @@ def _get_logprobs( ...@@ -500,6 +589,7 @@ def _get_logprobs(
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu() batched_logprobs_query_result = batched_logprobs_query_result.cpu()
batched_ranks_query_result = batched_ranks_query_result.cpu()
# Gather results # Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
...@@ -515,20 +605,27 @@ def _get_logprobs( ...@@ -515,20 +605,27 @@ def _get_logprobs(
if (i < sampling_metadata.num_prompts if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs num_logprobs = sampling_params.prompt_logprobs
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[ prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None] group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]: for token_id in prompt_tokens[1:]:
prompt_logprobs_dict = { prompt_logprobs_dict = {
token_id: token_id:
batched_logprobs_query_result[query_result_idx].item() (batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
} }
if num_logprobs > 0: if num_logprobs > 0:
prompt_logprobs_dict.update( prompt_logprobs_dict.update(
zip(top_token_ids[sample_idx, :num_logprobs].tolist(), zip(
top_logprobs[sample_idx, :num_logprobs].tolist())) top_token_ids[sample_idx, :num_logprobs].tolist(),
group_prompt_logprobs.append(prompt_logprobs_dict) zip(
top_logprobs[
sample_idx, :num_logprobs].tolist(),
range(1, num_logprobs + 1))))
group_prompt_logprobs.append({
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in prompt_logprobs_dict.items()
})
sample_idx += 1 sample_idx += 1
query_result_idx += 1 query_result_idx += 1
result_prompt_logprobs.append(group_prompt_logprobs) result_prompt_logprobs.append(group_prompt_logprobs)
...@@ -543,17 +640,23 @@ def _get_logprobs( ...@@ -543,17 +640,23 @@ def _get_logprobs(
for next_token_id, parent_id in zip(next_token_ids, parent_ids): for next_token_id, parent_id in zip(next_token_ids, parent_ids):
sample_logprobs_dict = { sample_logprobs_dict = {
next_token_id: next_token_id:
batched_logprobs_query_result[query_result_idx].item() (batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
} }
query_result_idx += 1 query_result_idx += 1
if num_logprobs > 0: if num_logprobs >= 0:
sample_logprobs_dict.update( sample_logprobs_dict.update(
zip( zip(
top_token_ids[sample_idx + top_token_ids[sample_idx +
parent_id, :num_logprobs].tolist(), parent_id, :num_logprobs].tolist(),
top_logprobs[sample_idx + zip(
parent_id, :num_logprobs].tolist())) top_logprobs[sample_idx +
group_sample_logprobs.append(sample_logprobs_dict) parent_id, :num_logprobs].tolist(),
range(1, num_logprobs + 1))))
group_sample_logprobs.append({
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in sample_logprobs_dict.items()
})
result_sample_logprobs.append(group_sample_logprobs) result_sample_logprobs.append(group_sample_logprobs)
sample_idx += len(seq_ids) sample_idx += len(seq_ids)
...@@ -581,4 +684,4 @@ def _build_sampler_output( ...@@ -581,4 +684,4 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append( sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
return sampler_output return SamplerOutput(outputs=sampler_output)
...@@ -4,13 +4,11 @@ import torch ...@@ -4,13 +4,11 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
......
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Type from typing import Tuple, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import DeviceConfig, ModelConfig from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
@contextlib.contextmanager @contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype): def _set_default_torch_dtype(dtype: torch.dtype):
...@@ -20,7 +25,8 @@ def _set_default_torch_dtype(dtype: torch.dtype): ...@@ -20,7 +25,8 @@ def _set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype) torch.set_default_dtype(old_dtype)
def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: def _get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
...@@ -31,16 +37,21 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: ...@@ -31,16 +37,21 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
for arch in architectures: for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch) model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None: if model_cls is not None:
return model_cls return (model_cls, arch)
raise ValueError( raise ValueError(
f"Model architectures {architectures} are not supported for now. " f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}") f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return the architecture name resolved for ``model_config``.

    This is the second element of the pair produced by
    `_get_model_architecture`.
    """
    _, arch_name = _get_model_architecture(model_config)
    return arch_name
def get_model(model_config: ModelConfig, device_config: DeviceConfig, def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module: **kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None) lora_config = kwargs.get("lora_config", None)
model_class = _get_model_architecture(model_config) vision_language_config = kwargs.get("vision_language_config", None)
model_class = _get_model_architecture(model_config)[0]
# Get the (maybe quantized) linear method. # Get the (maybe quantized) linear method.
linear_method = None linear_method = None
...@@ -76,7 +87,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, ...@@ -76,7 +87,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
"be added in the future. If this is important to you, " "be added in the future. If this is important to you, "
"please open an issue on github.") "please open an issue on github.")
else: else:
model = model_class(model_config.hf_config, linear_method) if model_class not in _VISION_MODEL_CLASSES:
model = model_class(model_config.hf_config, linear_method)
else:
model = model_class(model_config.hf_config,
vision_language_config, linear_method)
if model_config.load_format == "dummy": if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
......
...@@ -4,7 +4,7 @@ from typing import List, Optional, Type ...@@ -4,7 +4,7 @@ from typing import List, Optional, Type
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip, is_neuron from vllm.utils import is_hip
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -17,6 +17,8 @@ _MODELS = { ...@@ -17,6 +17,8 @@ _MODELS = {
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
...@@ -27,7 +29,10 @@ _MODELS = { ...@@ -27,7 +29,10 @@ _MODELS = {
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
...@@ -42,10 +47,12 @@ _MODELS = { ...@@ -42,10 +47,12 @@ _MODELS = {
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
} }
# Models not supported by ROCm. # Models not supported by ROCm.
...@@ -62,9 +69,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = { ...@@ -62,9 +69,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Sliding window attention is not yet supported in ROCm's flash attention", "Sliding window attention is not yet supported in ROCm's flash attention",
} }
# Models supported by Neuron (architecture name -> module name).
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
class ModelRegistry: class ModelRegistry:
...@@ -81,15 +85,8 @@ class ModelRegistry: ...@@ -81,15 +85,8 @@ class ModelRegistry:
logger.warning( logger.warning(
f"Model architecture {model_arch} is partially supported " f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
elif is_neuron():
if model_arch not in _NEURON_SUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"Neuron for now.")
module_name, model_cls_name = _MODELS[model_arch] module_name, model_cls_name = _MODELS[model_arch]
if is_neuron():
module_name = _NEURON_SUPPORTED_MODELS[model_arch]
module = importlib.import_module( module = importlib.import_module(
f"vllm.model_executor.models.{module_name}") f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None) return getattr(module, model_cls_name, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment