norm / vllm · Commit a130cf33

Commit a130cf33, authored Mar 06, 2024 by zhuwenwen
Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx
Parents: a2d181be, 82091b86

Changes: 106 files in the commit; this page shows 20 changed files with 709 additions and 72 deletions (+709, -72).
vllm/lora/punica.py  (+1, -1)
vllm/model_executor/__init__.py  (+1, -2)
vllm/model_executor/guided_decoding.py  (+99, -0)
vllm/model_executor/guided_logits_processors.py  (+129, -0)
vllm/model_executor/layers/activation.py  (+23, -0)
vllm/model_executor/layers/attention.py  (+18, -16)
vllm/model_executor/layers/fused_moe/__init__.py  (+5, -0)
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json  (+20, -0)
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json  (+24, -0)
vllm/model_executor/layers/fused_moe/configs/README  (+10, -0)
vllm/model_executor/layers/fused_moe/fused_moe.py  (+66, -16)
vllm/model_executor/layers/linear.py  (+29, -0)
vllm/model_executor/layers/quantization/__init__.py  (+2, -0)
vllm/model_executor/layers/quantization/gptq.py  (+9, -7)
vllm/model_executor/layers/quantization/marlin.py  (+210, -0)
vllm/model_executor/layers/rotary_embedding.py  (+2, -5)
vllm/model_executor/layers/sampler.py  (+13, -5)
vllm/model_executor/layers/triton_kernel/prefix_prefill.py  (+28, -13)
vllm/model_executor/model_loader.py  (+5, -5)
vllm/model_executor/models/__init__.py  (+15, -2)
vllm/lora/punica.py

@@ -87,7 +87,7 @@ def add_lora(y: torch.Tensor,
     r = wb_t_all.size(-1)
     if buffer is None:
         # We set the buffer to be float32 by default to avoid
-        # numerical inacuracies that would otherwise happen
+        # numerical inaccuracies that would otherwise happen
         # due to downcasting.
         buffer = torch.zeros((x.size(0), r),
                              dtype=torch.float32,
vllm/model_executor/__init__.py

 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.utils import set_random_seed, get_model

 __all__ = [
     "InputMetadata",
vllm/model_executor/guided_decoding.py (new file, mode 100644)

import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Union, Tuple

from pydantic import BaseModel

from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"


global_thread_pool = None  # used for generating logits processor fsm


async def get_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode)

    logits_processor = copy(result)
    # reset logits processor's internal state
    logits_processor.init_state()
    return logits_processor


def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[str, GuidedDecodingMode]:

    if request.guided_json:
        if not isinstance(request.guided_json, (str, dict, BaseModel)):
            raise TypeError("JSON schema must be str, dict, or BaseModel")

        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        if not isinstance(request.guided_regex, str):
            raise TypeError("Regex must be string")
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        if not isinstance(request.guided_choice, list):
            raise TypeError("Choices must be a list")

        # choice just uses regex
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    else:
        return None, None


@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
                                 mode: GuidedDecodingMode):
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
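Note on the caching scheme above: _get_cached_logits_processor builds the expensive FSM-backed processor once per distinct guide, and each request then works on a shallow copy with freshly initialized state. A minimal stand-alone sketch of that pattern, with a made-up DummyProcessor standing in for the real logits processors (not part of the commit):

from copy import copy
from functools import lru_cache

class DummyProcessor:
    def __init__(self, guide: str):
        self.guide = guide          # expensive-to-build state (the FSM in vLLM)
        self.state = None

    def init_state(self):
        self.state = {}             # per-request mutable state

@lru_cache(maxsize=32)
def _cached(guide: str) -> DummyProcessor:
    return DummyProcessor(guide)    # built once per distinct guide

# Two requests with the same guide share the compiled object...
a, b = copy(_cached("[0-9]+")), copy(_cached("[0-9]+"))
a.init_state(); b.init_state()
# ...but each shallow copy gets fresh per-request state.
assert a.guide is b.guide and a.state is not b.state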
vllm/model_executor/guided_logits_processors.py (new file, mode 100644)

# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional

import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema


class RegexLogitsProcessor:

    def __init__(self, regex_string: str, tokenizer):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer
        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm

    def init_state(self):
        """Initialize the FSM states."""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) == 0:
            self.init_state()
        else:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], last_token)

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)

        return scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer to use to compile the FSM.

        The API of Outlines tokenizers is slightly different to that of
        `transformers`. In addition we need to handle the missing spaces to
        Llama's tokenizer to be able to compile FSMs for this model.
        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces to HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(self,
                 schema: Union[str, Dict, BaseModel],
                 tokenizer,
                 whitespace_pattern: Optional[str] = None):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to generate
        tokenizer
            The model's tokenizer
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
            Example: allow only a single space or newline with `whitespace_pattern=r"[ \n]?"`
        """
        if isinstance(schema, type(BaseModel)):
            schema_str = json.dumps(schema.model_json_schema())
        elif isinstance(schema, Dict):
            schema_str = json.dumps(schema)
        elif isinstance(schema, str):
            schema_str = schema
        else:
            raise ValueError(
                f"Cannot parse schema {schema}. The schema must be either " +
                "a Pydantic object, a dictionary or a string that contains the JSON "
                + "Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)
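The heart of RegexLogitsProcessor.__call__ is the masking step: every token the FSM does not allow in the current state is pushed to negative infinity before sampling. A small self-contained sketch of just that step; the vocabulary size and the allowed-token ids below are illustrative, not taken from the commit:

import math
import torch

vocab_size = 8
scores = torch.randn(vocab_size)            # raw logits for the next token
allowed_tokens = [2, 5, 7]                  # ids the FSM permits in this state

mask = torch.full((scores.shape[-1], ), -math.inf)
mask[allowed_tokens] = 0
scores.add_(mask)                           # disallowed tokens become -inf

assert torch.isfinite(scores[allowed_tokens]).all()
assert torch.isinf(scores[[0, 1, 3]]).all()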
vllm/model_executor/layers/activation.py

@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
         return out


+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):

     def _forward(self, x: torch.Tensor) -> torch.Tensor:
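A quick check of the GeGLU math documented in GeluAndMul; shapes here are illustrative and only the PyTorch-native _forward path is exercised, not the custom ops.gelu_and_mul kernel:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)                     # (batch, seq_len, 2 * d) with d = 8
d = x.shape[-1] // 2
out = F.gelu(x[..., :d]) * x[..., d:]         # GELU(first half) * second half
assert out.shape == (2, 4, 8)                 # last dim halved, as documented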
vllm/model_executor/layers/attention.py

@@ -137,25 +137,27 @@ class PagedAttention(nn.Module):
         )

         if input_metadata.is_prompt:
             # Prompt run.
-            if self.num_kv_heads != self.num_heads:
-                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                # project the key and value tensors to the desired number of
-                # heads.
-                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-                query = query.view(query.shape[0], self.num_kv_heads,
-                                   self.num_queries_per_kv, query.shape[-1])
-                key = key[:, :, None, :].expand(key.shape[0],
-                                                self.num_kv_heads,
-                                                self.num_queries_per_kv,
-                                                key.shape[-1])
-                value = value[:, :, None, :].expand(value.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    value.shape[-1])
             # normal attention
             if (key_cache is None or value_cache is None
                     or input_metadata.block_tables.numel() == 0):
+                if self.num_kv_heads != self.num_heads:
+                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                    # project the key and value tensors to the desired number of
+                    # heads.
+                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                    query = query.view(query.shape[0], self.num_kv_heads,
+                                       self.num_queries_per_kv,
+                                       query.shape[-1])
+                    key = key[:, :, None, :].expand(key.shape[0],
+                                                    self.num_kv_heads,
+                                                    self.num_queries_per_kv,
+                                                    key.shape[-1])
+                    value = value[:, :, None, :].expand(value.shape[0],
+                                                        self.num_kv_heads,
+                                                        self.num_queries_per_kv,
+                                                        value.shape[-1])
                 # Set attention bias if not provided. This typically happens at
                 # the very attention layer of every iteration.
                 # FIXME(woosuk): This is a hack.
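For context, the MQA/GQA reshaping that this hunk moves into the xformers branch can be sketched outside vLLM as follows; the head counts and token count are illustrative:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads            # 4 query heads per kv head

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)

query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
# expand() creates a broadcast view (no copy); each kv head now lines up with
# its group of query heads.
assert query.shape == key.shape == (5, 2, 4, 64)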
vllm/model_executor/layers/fused_moe/__init__.py (new file, mode 100644)

from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (new file, mode 100644)

{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (new file, mode 100644)

{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/README (new file, mode 100644)

This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
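As a small illustration of the naming scheme this README describes, the snippet below assembles a config file name from E, N and the device name, mirroring get_moe_configs() in fused_moe.py further down; the device name is hardcoded here for illustration instead of calling torch.cuda.get_device_name():

E, N = 8, 3584
device_name = "NVIDIA A100-SXM4-80GB".replace(" ", "_")
filename = f"E={E},N={N},device_name={device_name}.json"
assert filename == "E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json"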
vllm/model_executor/layers/fused_moe.py → vllm/model_executor/layers/fused_moe/fused_moe.py (renamed)

 """Fused MoE kernel."""
+import functools
+import json
+import os
+from typing import Any, Dict, Optional, Tuple
+
 import torch
 import triton
 import triton.language as tl

 from vllm._C import ops
+from vllm.logger import init_logger
 from vllm.utils import is_hip

+logger = init_logger(__name__)
+

 @triton.jit
 def fused_moe_kernel(

@@ -129,7 +137,7 @@ def fused_moe_kernel(

 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block size for matrix multiplication.

@@ -177,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -210,6 +219,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )


+@functools.lru_cache
+def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of batch sizes
+    to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
+    size bs, the closest batch size in the grid should be picked and the associated
+    configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs directory
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs",
+        f"E={E},N={N},device_name={device_name}.json")
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                f"Using configuration from {config_file_path} for MoE layer.")
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default configuration
+    return None
+
+
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,

@@ -218,6 +255,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.

@@ -230,6 +268,7 @@ def fused_moe(
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place. Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.

@@ -279,20 +318,31 @@ def fused_moe(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-    config = {
-        'BLOCK_SIZE_M': 64,
-        'BLOCK_SIZE_N': 64,
-        'BLOCK_SIZE_K': 32,
-        'GROUP_SIZE_M': 8
-    }
-
-    if topk_ids.numel() <= w1.shape[0]:
-        config = {
-            'BLOCK_SIZE_M': 16,
-            'BLOCK_SIZE_N': 32,
-            'BLOCK_SIZE_K': 64,
-            'GROUP_SIZE_M': 1
-        }
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2])
+
+        if configs:
+            # If an optimal configuration map has been found, look up the optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32,
+                'GROUP_SIZE_M': 8
+            }
+
+            if M <= E:
+                config = {
+                    'BLOCK_SIZE_M': 16,
+                    'BLOCK_SIZE_N': 32,
+                    'BLOCK_SIZE_K': 64,
+                    'GROUP_SIZE_M': 1
+                }

     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
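The "closest batch size" lookup that fused_moe() now performs can be seen in isolation below; the config dictionary is a made-up subset of a tuned file, not taken from the commit:

configs = {1: {"BLOCK_SIZE_M": 16}, 64: {"BLOCK_SIZE_M": 32}, 512: {"BLOCK_SIZE_M": 64}}
M = 100                                   # current batch size
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
assert config == {"BLOCK_SIZE_M": 32}     # 64 is the nearest tuned batch size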
vllm/model_executor/layers/linear.py

@@ -17,6 +17,14 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)


+def adjust_marlin_shard(param, shard_size, shard_offset):
+    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+    if marlin_tile_size is None:
+        return shard_size, shard_offset
+
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
 class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)

@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size

@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                       loaded_shard_id: Optional[str] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:

@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)

@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
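A quick illustration of what adjust_marlin_shard does to a shard: parameters tagged with marlin_tile_size get their shard size and offset scaled by the tile size, everything else passes through unchanged. _Param is a stand-in object for this sketch, not vLLM code:

class _Param:                      # minimal stand-in for a torch Parameter
    pass

def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size

plain = _Param()
marlin = _Param()
marlin.marlin_tile_size = 16

assert adjust_marlin_shard(plain, 128, 256) == (128, 256)
assert adjust_marlin_shard(marlin, 128, 256) == (2048, 4096)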
vllm/model_executor/layers/quantization/__init__.py

@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig

 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
+    "marlin": MarlinConfig,
 }
vllm/model_executor/layers/quantization/gptq.py

 import enum
 from enum import Enum
 from typing import Any, Dict, List, Optional
+from fractions import Fraction

 import torch
 from torch.nn.parameter import Parameter

@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
-        self.pack_factor = 32 // self.weight_bits
-        # exllama kernel v1 only supports 4 bit
-        if self.weight_bits != 4:
+        self.pack_factor = Fraction(32, self.weight_bits)
+        if self.weight_bits not in [2, 3, 4, 8]:
             raise ValueError(
-                "Currently, only 4-bit weight quantization is supported for "
+                "Currently, only 2/3/4/8-bit weight quantization is supported for "
                 f"GPTQ, but got {self.weight_bits} bits.")

     def __repr__(self) -> str:

@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size_per_partition % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "

@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
             else:
                 weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
+            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                             self.quant_config.weight_bits)
         output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                                weights["qzeros"], weights["scales"],
                                weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY)
+                               weights["exllama_state"] == ExllamaState.READY,
+                               self.quant_config.weight_bits)
         if bias is not None:
             output = output + bias
         return output.reshape(out_shape)
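Why pack_factor becomes a Fraction above: for 3-bit weights, 32 // 3 silently truncates, while Fraction keeps the exact ratio and its numerator can be used for the alignment check shown in the diff. The sizes in this sketch are illustrative:

from fractions import Fraction

weight_bits = 3
assert 32 // weight_bits == 10                      # integer division is lossy
pack_factor = Fraction(32, weight_bits)             # exact 32/3
assert pack_factor.numerator == 32 and pack_factor.denominator == 3
# e.g. an output dim of 4096 passes the numerator-based alignment check:
assert 4096 % pack_factor.numerator == 0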
vllm/model_executor/layers/quantization/marlin.py (new file, mode 100644)

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) is supported for "
                f"Marlin, but got group_size of {self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size}"

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    def get_linear_method(self) -> "MarlinLinearMethod":
        return MarlinLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
            )
        if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        return {
            "B": qweight,
            "s": scales,
            "workspace": workspace,
        }

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = weights["B"]
        scales = weights["s"]
        workspace = weights["workspace"]

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
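The shape arithmetic behind the qweight allocation above can be sanity-checked with the constants from MarlinConfig (tile_size = 16, pack_factor = 32 // 4 = 8); the layer sizes below are illustrative:

tile_size, pack_factor = 16, 32 // 4
input_size_per_partition, output_size_per_partition = 4096, 11008

rows = input_size_per_partition // tile_size                      # 256
cols = output_size_per_partition * tile_size // pack_factor       # 22016
# Each packed int32 element holds 8 four-bit weights, so the packed tensor
# accounts for exactly input_size x output_size weights.
assert rows * cols * pack_factor == input_size_per_partition * output_size_per_partition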
vllm/model_executor/layers/rotary_embedding.py

@@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int,

 def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
-                           dtype: torch.dtype,
-                           device: torch.device) -> torch.Tensor:
+                           dtype: torch.dtype) -> torch.Tensor:
     if low == high:
         high += 0.001  # Prevent singularity

-    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
-                   low) / (high - low)
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
     ramp_func = torch.clamp(linear_func, 0, 1)
     return ramp_func

@@ -356,7 +354,6 @@ def get_rope(
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
-            assert max_position == original_max_position * scaling_factor
             extra_kwargs = {
                 k: v
                 for k, v in rope_scaling.items()
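For reference, _yarn_linear_ramp_mask is just a clamped linear ramp over the rotary dimensions; a tiny numeric sketch with illustrative low/high values:

import torch

low, high, dim = 2.0, 6.0, 8
ramp = torch.clamp(
    (torch.arange(dim, dtype=torch.float32) - low) / (high - low), 0, 1)
# ramp == tensor([0.00, 0.00, 0.00, 0.25, 0.50, 0.75, 1.00, 1.00])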
vllm/model_executor/layers/sampler.py

@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron


 class Sampler(nn.Module):

@@ -32,6 +33,8 @@ class Sampler(nn.Module):
                  org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generate outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size

@@ -55,10 +58,14 @@ class Sampler(nn.Module):
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)

-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because

@@ -395,7 +402,8 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):

@@ -407,7 +415,7 @@ def _sample(
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
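The .long() casts added in _sample ensure the sampling indices are int64 before they are used for advanced indexing; a stand-alone illustration with made-up shapes:

import torch

logprobs = torch.randn(4, 10)
sample_indices = torch.tensor([0, 2], dtype=torch.int32)

selected = logprobs[sample_indices.long()]        # int64 indices are always valid
greedy_samples = torch.argmax(selected, dim=-1)
assert selected.shape == (2, 10) and greedy_samples.shape == (2,)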
vllm/model_executor/layers/triton_kernel/prefix_prefill.py

@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,

@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
-                cur_head * stride_v_cache_h +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,

@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v

@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,

@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
-                cur_head * stride_v_cache_h +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,

@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v

@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,

@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
         # cur_batch_seq_len: the length of prompts
         # cur_batch_ctx_len: the length of prefix
         # cur_batch_in_all_start_index: the start id of the dim=0

@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
-                cur_head * stride_v_cache_h +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,

@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v

@@ -537,7 +549,7 @@ if triton.__version__ >= "2.1.0":
         alibi_start_q = tl.arange(
             0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
         alibi_start_k = cur_batch_ctx_len
-        # # init debuger
+        # # init debugger
         # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
         # offset_db_k = tl.arange(0, BLOCK_N)
         # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]

@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0":
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
+        num_queries_per_kv = q.shape[1] // k.shape[1]

         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,

@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
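The new cur_kv_head = cur_head // num_queries_per_kv line is the standard GQA head grouping: each group of num_queries_per_kv query heads shares one kv head. Outside Triton it looks like this (head counts are illustrative):

num_heads, num_kv_heads = 32, 8
num_queries_per_kv = num_heads // num_kv_heads          # 4

kv_head_for_query = [cur_head // num_queries_per_kv for cur_head in range(num_heads)]
assert kv_head_for_query[:8] == [0, 0, 0, 0, 1, 1, 1, 1]
assert max(kv_head_for_query) == num_kv_heads - 1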
vllm/model_executor/model_loader.py

 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Type
+from typing import Type

 import torch
 import torch.nn as nn

-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)

@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")


 def get_model(model_config: ModelConfig,
               device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)

     # Get the (maybe quantized) linear method.
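The switch to **kwargs keeps get_model's positional signature stable while optional extras such as lora_config are picked up when present. A tiny stand-in sketch (get_model_stub is hypothetical, not vLLM code):

def get_model_stub(model_config, device_config, **kwargs):
    lora_config = kwargs.get("lora_config", None)   # absent -> None
    return (model_config, device_config, lora_config)

assert get_model_stub("cfg", "dev") == ("cfg", "dev", None)
assert get_model_stub("cfg", "dev", lora_config="lora") == ("cfg", "dev", "lora")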
vllm/model_executor/models/__init__.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Type
 import torch.nn as nn

 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron

 logger = init_logger(__name__)

@@ -30,7 +30,7 @@ _MODELS = {
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case

@@ -38,11 +38,14 @@ _MODELS = {
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
 }

 # Models not supported by ROCm.

@@ -59,6 +62,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }

+# Models not supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+

 class ModelRegistry:

@@ -75,8 +81,15 @@ class ModelRegistry:
             logger.warning(
                 f"Model architecture {model_arch} is partially supported "
                 "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        elif is_neuron():
+            if model_arch not in _NEURON_SUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "Neuron for now.")

         module_name, model_cls_name = _MODELS[model_arch]
+        if is_neuron():
+            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)
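The registry lookup that ModelRegistry performs is a plain lazy import keyed by architecture name; a miniature version of the same pattern, using the standard-library json module as a stand-in target instead of a vLLM model module:

import importlib

_MODELS = {"JSONDecoder": ("json", "JSONDecoder")}   # arch -> (module, class)

def load_model_class(model_arch):
    module_name, model_cls_name = _MODELS[model_arch]
    module = importlib.import_module(module_name)    # imported only on demand
    return getattr(module, model_cls_name, None)

assert load_model_class("JSONDecoder") is __import__("json").JSONDecoder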