norm / vllm · Commits

Commit a130cf33, authored Mar 06, 2024 by zhuwenwen

    Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

Parents: a2d181be, 82091b86

Changes: 106 files in the full merge; this page shows 20 changed files with 709 additions and 72 deletions (+709, −72).
Changed files shown on this page:

vllm/lora/punica.py (+1, −1)
vllm/model_executor/__init__.py (+1, −2)
vllm/model_executor/guided_decoding.py (+99, −0)
vllm/model_executor/guided_logits_processors.py (+129, −0)
vllm/model_executor/layers/activation.py (+23, −0)
vllm/model_executor/layers/attention.py (+18, −16)
vllm/model_executor/layers/fused_moe/__init__.py (+5, −0)
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (+20, −0)
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (+24, −0)
vllm/model_executor/layers/fused_moe/configs/README (+10, −0)
vllm/model_executor/layers/fused_moe/fused_moe.py (+66, −16)
vllm/model_executor/layers/linear.py (+29, −0)
vllm/model_executor/layers/quantization/__init__.py (+2, −0)
vllm/model_executor/layers/quantization/gptq.py (+9, −7)
vllm/model_executor/layers/quantization/marlin.py (+210, −0)
vllm/model_executor/layers/rotary_embedding.py (+2, −5)
vllm/model_executor/layers/sampler.py (+13, −5)
vllm/model_executor/layers/triton_kernel/prefix_prefill.py (+28, −13)
vllm/model_executor/model_loader.py (+5, −5)
vllm/model_executor/models/__init__.py (+15, −2)
vllm/lora/punica.py

@@ -87,7 +87,7 @@ def add_lora(y: torch.Tensor,
    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
-       # numerical innacuracies that would otherwise happen
+       # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
vllm/model_executor/__init__.py

 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.utils import set_random_seed, get_model

 __all__ = [
     "InputMetadata",
vllm/model_executor/guided_decoding.py (new file, mode 100644)

import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Union, Tuple

from pydantic import BaseModel

from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              ChatCompletionRequest)
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
                                                          RegexLogitsProcessor)


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"


global_thread_pool = None  # used for generating logits processor fsm


async def get_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode)

    logits_processor = copy(result)
    # reset logits processor's internal state
    logits_processor.init_state()
    return logits_processor


def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[str, GuidedDecodingMode]:

    if request.guided_json:
        if not isinstance(request.guided_json, (str, dict, BaseModel)):
            raise TypeError("JSON schema must be str, dict, or BaseModel")

        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        if not isinstance(request.guided_regex, str):
            raise TypeError("Regex must be string")
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        if not isinstance(request.guided_choice, list):
            raise TypeError("Choices must be a list")

        # choice just uses regex
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    else:
        return None, None


@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
                                 mode: GuidedDecodingMode):
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
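As a quick orientation for how these helpers fit together, here is a minimal, hedged sketch of the guided_choice path. The request object below is a stand-in carrying only the three guided_* fields; it is not the real CompletionRequest constructor from vllm.entrypoints.openai.protocol.

# Minimal sketch, assuming a request-like object with guided_json,
# guided_regex and guided_choice attributes.
from types import SimpleNamespace

request = SimpleNamespace(guided_json=None,
                          guided_regex=None,
                          guided_choice=["yes", "no", "maybe"])

guide, mode = _get_guide_and_mode(request)
# guide == "(yes|no|maybe)" and mode == GuidedDecodingMode.CHOICE, so the
# cached processor built for this guide is a RegexLogitsProcessor.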
vllm/model_executor/guided_logits_processors.py (new file, mode 100644)

# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional

import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema


class RegexLogitsProcessor:

    def __init__(self, regex_string: str, tokenizer):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm

    def init_state(self):
        """Initialize the FSM states."""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) == 0:
            self.init_state()
        else:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], last_token)

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)

        return scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer to use to compile the FSM.

        The API of Outlines tokenizers is slightly different to that of
        `transformers`. In addition we need to handle the missing spaces to
        Llama's tokenizer to be able to compile FSMs for this model.

        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces to HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(self,
                 schema: Union[str, Dict, BaseModel],
                 tokenizer,
                 whitespace_pattern: Optional[str] = None):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to
            generate
        tokenizer
            The model's tokenizer
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace (doesn't impact
            string literals). Example: allow only a single space or newline
            with `whitespace_pattern=r"[\n ]?"`

        """
        if isinstance(schema, type(BaseModel)):
            schema_str = json.dumps(schema.model_json_schema())
        elif isinstance(schema, Dict):
            schema_str = json.dumps(schema)
        elif isinstance(schema, str):
            schema_str = schema
        else:
            raise ValueError(
                f"Cannot parse schema {schema}. The schema must be either " +
                "a Pydantic object, a dictionary or a string that contains "
                "the JSON Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)
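The masking step inside __call__ is the heart of the processor. Below is a tiny, self-contained sketch of just that step, with a made-up vocabulary size and a made-up allowed-token set standing in for what the Outlines RegexFSM would return for the current FSM state.

# Sketch of the logit-masking step only; allowed_tokens here is illustrative.
import math
import torch

scores = torch.randn(8)        # pretend vocabulary of 8 tokens
allowed_tokens = [1, 4, 5]     # tokens the FSM permits in this state

mask = torch.full((scores.shape[-1], ), -math.inf)
mask[allowed_tokens] = 0
scores.add_(mask)              # disallowed logits are pushed to -inf

assert torch.isinf(scores[0]) and not torch.isinf(scores[4])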
vllm/model_executor/layers/activation.py

@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
        return out


+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
vllm/model_executor/layers/attention.py

@@ -137,25 +137,27 @@ class PagedAttention(nn.Module):
            )

        if input_metadata.is_prompt:
            # Prompt run.
-           if self.num_kv_heads != self.num_heads:
-               # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-               # project the key and value tensors to the desired number of
-               # heads.
-               # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-               query = query.view(query.shape[0], self.num_kv_heads,
-                                  self.num_queries_per_kv, query.shape[-1])
-               key = key[:, :, None, :].expand(key.shape[0],
-                                               self.num_kv_heads,
-                                               self.num_queries_per_kv,
-                                               key.shape[-1])
-               value = value[:, :, None, :].expand(value.shape[0],
-                                                   self.num_kv_heads,
-                                                   self.num_queries_per_kv,
-                                                   value.shape[-1])
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
+               if self.num_kv_heads != self.num_heads:
+                   # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                   # project the key and value tensors to the desired number of
+                   # heads.
+                   # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                   query = query.view(query.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv,
+                                      query.shape[-1])
+                   key = key[:, :, None, :].expand(key.shape[0],
+                                                   self.num_kv_heads,
+                                                   self.num_queries_per_kv,
+                                                   key.shape[-1])
+                   value = value[:, :, None, :].expand(value.shape[0],
+                                                       self.num_kv_heads,
+                                                       self.num_queries_per_kv,
+                                                       value.shape[-1])

                # Set attention bias if not provided. This typically happens at
                # the very attention layer of every iteration.
                # FIXME(woosuk): This is a hack.
vllm/model_executor/layers/fused_moe/__init__.py (new file, mode 100644)

from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (new file, mode 100644)

{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (new file, mode 100644)

{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
vllm/model_executor/layers/fused_moe/configs/README (new file, mode 100644)
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
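A hedged sketch of how a caller could consume one of these files under the naming and lookup scheme described above. The helper name and the example path are illustrative; inside vLLM the actual lookup is done by get_moe_configs in fused_moe.py.

# Illustrative only: read a tuned config file and pick the entry whose tuned
# batch size is closest to the runtime batch size M, per this README.
import json

def load_moe_config(path: str, M: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    closest_m = min(configs.keys(), key=lambda m: abs(m - M))
    return configs[closest_m]

# e.g. Mixtral TP2 on H100 (E=8, N=7168) with a batch of 300 tokens:
# cfg = load_moe_config(
#     "E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json", M=300)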
vllm/model_executor/layers/fused_moe.py → vllm/model_executor/layers/fused_moe/fused_moe.py

"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl

from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)


@triton.jit
def fused_moe_kernel(

@@ -129,7 +137,7 @@ def fused_moe_kernel(
def moe_align_block_size(
        topk_ids: torch.Tensor, block_size: int,
-       num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+       num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block size for matrix multiplication.

@@ -177,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
-                           mul_routed_weight: bool, top_k: int,
-                           config: dict):
+                           mul_routed_weight: bool, top_k: int,
+                           config: Dict[str, Any]) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

@@ -210,6 +219,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
    )


+@functools.lru_cache
+def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of batch
+    sizes to configurations of the fused_moe kernel. To evaluate the kernel on
+    a given batch size bs, the closest batch size in the grid should be picked
+    and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs",
+        f"E={E},N={N},device_name={device_name}.json")
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                f"Using configuration from {config_file_path} for MoE layer.")
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    return None
+
+
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,

@@ -218,6 +255,7 @@ def fused_moe(
    topk: int,
    renormalize: bool,
    inplace: bool = False,
+   override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.

@@ -230,6 +268,7 @@ def fused_moe(
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place. Defaults to False.
+   - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.

@@ -279,20 +318,31 @@ def fused_moe(
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-   config = {
-       'BLOCK_SIZE_M': 64,
-       'BLOCK_SIZE_N': 64,
-       'BLOCK_SIZE_K': 32,
-       'GROUP_SIZE_M': 8
-   }
-
-   if topk_ids.numel() <= w1.shape[0]:
-       config = {
-           'BLOCK_SIZE_M': 16,
-           'BLOCK_SIZE_N': 32,
-           'BLOCK_SIZE_K': 64,
-           'GROUP_SIZE_M': 1
-       }
+   if override_config:
+       config = override_config
+   else:
+       # First try to load optimal config from the file
+       configs = get_moe_configs(E, w2.shape[2])
+
+       if configs:
+           # If an optimal configuration map has been found, look up the
+           # optimal config
+           config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+       else:
+           # Else use the default config
+           config = {
+               'BLOCK_SIZE_M': 64,
+               'BLOCK_SIZE_N': 64,
+               'BLOCK_SIZE_K': 32,
+               'GROUP_SIZE_M': 8
+           }
+
+           if M <= E:
+               config = {
+                   'BLOCK_SIZE_M': 16,
+                   'BLOCK_SIZE_N': 32,
+                   'BLOCK_SIZE_K': 64,
+                   'GROUP_SIZE_M': 1
+               }

    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                      device=hidden_states.device,
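For completeness, a hedged sketch of how calling code could use the new override_config parameter to bypass both the tuned JSON files and the heuristic default. The block values below are made up (not tuned), and the call is left commented because it needs CUDA/Triton and real tensors; any argument names not visible in the hunks above (e.g. gating_output) are assumptions.

# Hypothetical override: force a specific kernel configuration.
my_config = {
    'BLOCK_SIZE_M': 64,
    'BLOCK_SIZE_N': 128,
    'BLOCK_SIZE_K': 32,
    'GROUP_SIZE_M': 8,
    'num_warps': 4,
    'num_stages': 4,
}
# out = fused_moe(hidden_states, w1, w2, gating_output,
#                 topk=2, renormalize=True, override_config=my_config)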
vllm/model_executor/layers/linear.py

@@ -17,6 +17,14 @@ from vllm.logger import init_logger
logger = init_logger(__name__)


+def adjust_marlin_shard(param, shard_size, shard_offset):
+    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+    if marlin_tile_size is None:
+        return shard_size, shard_offset
+
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""

@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
+
+                   # If marlin, we need to adjust the offset and size to
+                   # account for the tiling.
+                   shard_size, shard_offset = adjust_marlin_shard(
+                       param, shard_size, shard_offset)
+
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)

@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
+
+               # If marlin, we need to adjust the offset and size to account
+               # for the tiling.
+               shard_size, shard_offset = adjust_marlin_shard(
+                   param, shard_size, shard_offset)
+
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size

@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
+
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:

@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
+
+                   # If marlin, we need to adjust the offset and size to
+                   # account for the tiling.
+                   shard_size, shard_offset = adjust_marlin_shard(
+                       param, shard_size, shard_offset)
+
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)

@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
+
+               # If marlin, we need to adjust the offset and size to account
+               # for the tiling.
+               shard_size, shard_offset = adjust_marlin_shard(
+                   param, shard_size, shard_offset)
+
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
vllm/model_executor/layers/quantization/__init__.py

@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig

_QUANTIZATION_CONFIG_REGISTRY = {
    "awq": AWQConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
+   "marlin": MarlinConfig,
}
vllm/model_executor/layers/quantization/gptq.py

import enum
from enum import Enum
from typing import Any, Dict, List, Optional
+from fractions import Fraction

import torch
from torch.nn.parameter import Parameter

@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
-       self.pack_factor = 32 // self.weight_bits
-       # exllama kernel v1 only supports 4 bit
-       if self.weight_bits != 4:
+       self.pack_factor = Fraction(32, self.weight_bits)
+       if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
-               "Currently, only 4-bit weight quantization is supported for "
+               "Currently, only 2/3/4/8-bit weight quantization is supported for "
                f"GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:

@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
-       if output_size_per_partition % self.quant_config.pack_factor != 0:
+       if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "

@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
        else:
            weights["g_idx"] = torch.empty((1, 1), device="meta")
            weights["exllama_state"] = ExllamaState.READY
-           ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
+           ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                            self.quant_config.weight_bits)
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
-                              weights["exllama_state"] == ExllamaState.READY)
+                              weights["exllama_state"] == ExllamaState.READY,
+                              self.quant_config.weight_bits)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
vllm/model_executor/layers/quantization/marlin.py (new file, mode 100644)

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) is supported for "
                f"Marlin, but got group_size of {self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size}"

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    def get_linear_method(self) -> "MarlinLinearMethod":
        return MarlinLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition}"
                f" is not divisible by min_n_threads = "
                f"{self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition}"
                f" is not divisible by pack_factor = "
                f"{self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by min_k_threads = "
                f"{self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = f{input_size_per_partition}"
                f" is not divisible by group_size = "
                f"{self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        return {
            "B": qweight,
            "s": scales,
            "workspace": workspace,
        }

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = weights["B"]
        scales = weights["s"]
        workspace = weights["workspace"]

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
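To make the packing arithmetic in create_weights concrete, here is a small hedged sketch (pure Python, no GPU) of the qweight and scales shapes that would be allocated for one example shard; the shard sizes are made up for illustration.

# Illustrative Marlin shape arithmetic: 4-bit weights, tile_size=16,
# pack_factor=32//4=8, group_size=128. Example shard sizes are made up.
tile_size, pack_factor, group_size = 16, 32 // 4, 128
input_size_per_partition, output_size_per_partition = 4096, 6144

qweight_shape = (input_size_per_partition // tile_size,
                 output_size_per_partition * tile_size // pack_factor)
scales_shape = (input_size_per_partition // group_size,
                output_size_per_partition)

print(qweight_shape)  # (256, 12288) packed int32 values
print(scales_shape)   # (32, 6144) fp16 group scales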
vllm/model_executor/layers/rotary_embedding.py

@@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int,
-def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
-                           dtype: torch.dtype,
-                           device: torch.device) -> torch.Tensor:
+def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
+                           dtype: torch.dtype) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

-   linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
+   linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func

@@ -356,7 +354,6 @@ def get_rope(
    elif scaling_type == "yarn":
        original_max_position = rope_scaling[
            "original_max_position_embeddings"]
-       assert max_position == original_max_position * scaling_factor
        extra_kwargs = {
            k: v
            for k, v in rope_scaling.items()
vllm/model_executor/layers/sampler.py

@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                           SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron


class Sampler(nn.Module):

@@ -32,6 +33,8 @@ class Sampler(nn.Module):
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
+       # Transformers-neuronx generate outputs as logits directly.
+       self.logits_as_hidden_states = is_neuron()
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

@@ -55,10 +58,14 @@ class Sampler(nn.Module):
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[SamplerOutput]:
        # Get the hidden states that we use for sampling.
-       hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-       # Get the logits for the next tokens.
-       logits = self._get_logits(hidden_states, embedding, embedding_bias)
+       if self.logits_as_hidden_states:
+           logits = hidden_states
+       else:
+           hidden_states = _prune_hidden_states(hidden_states,
+                                                sampling_metadata)
+           # Get the logits for the next tokens.
+           logits = self._get_logits(hidden_states, embedding,
+                                     embedding_bias)

        # Only perform sampling in the driver worker.
        # Note: `_get_logits` is still distributed across TP workers because

@@ -395,7 +402,8 @@ def _sample(
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
-           greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+           greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                         dim=-1)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):

@@ -407,7 +415,7 @@ def _sample(
                    "generators": sampling_metadata.generators,
                }
            multinomial_samples[sampling_type] = _multinomial(
-               probs[sample_indices], max_best_of, **seeded_args)
+               probs[sample_indices.long()], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
vllm/model_executor/layers/triton_kernel/prefix_prefill.py

@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
+       num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,

@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

+       cur_kv_head = cur_head // num_queries_per_kv
+
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
-                    cur_head * stride_k_cache_h +
+                    cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
-           off_v = (bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+           off_v = (bn[:, None] * stride_v_cache_bs +
+                    cur_kv_head * stride_v_cache_h +
                     offs_d[None, :] * stride_v_cache_d +
                     (start_n + offs_n[:, None]) % block_size *
                     stride_v_cache_bl)
            k = tl.load(K_cache + off_k,

@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
                l_i = l_i_new
                m_i = m_i_new

-       off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+       off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
-       off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+       off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
+       num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,

@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

+       cur_kv_head = cur_head // num_queries_per_kv
+
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
-                    cur_head * stride_k_cache_h +
+                    cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
-           off_v = (bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+           off_v = (bn[:, None] * stride_v_cache_bs +
+                    cur_kv_head * stride_v_cache_h +
                     offs_d[None, :] * stride_v_cache_d +
                     (start_n + offs_n[:, None]) % block_size *
                     stride_v_cache_bl)
            k = tl.load(K_cache + off_k,

@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
                l_i = l_i_new
                m_i = m_i_new

-       off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+       off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
-       off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+       off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
+       num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,

@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

+       cur_kv_head = cur_head // num_queries_per_kv
+
        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0

@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
-                    cur_head * stride_k_cache_h +
+                    cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
-           off_v = (bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+           off_v = (bn[:, None] * stride_v_cache_bs +
+                    cur_kv_head * stride_v_cache_h +
                     offs_d[None, :] * stride_v_cache_d +
                     (start_n + offs_n[:, None]) % block_size *
                     stride_v_cache_bl)
            k = tl.load(K_cache + off_k,

@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
                l_i = l_i_new
                m_i = m_i_new

-       off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+       off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
-       off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+       off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

@@ -537,7 +549,7 @@ if triton.__version__ >= "2.1.0":
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
-       # # init debuger
+       # # init debugger
        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
        # offset_db_k = tl.arange(0, BLOCK_N)
        # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]

@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0":
        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
+       num_queries_per_kv = q.shape[1] // k.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0":
            v_cache.stride(2),
            v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
+           num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,

@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0":
            v_cache.stride(2),
            v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
+           num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
vllm/model_executor/model_loader.py

"""Utilities for selecting and loading models."""
import contextlib
-from typing import Optional, Type
+from typing import Type

import torch
import torch.nn as nn

-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)

@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


-def get_model(model_config: ModelConfig, device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+   lora_config = kwargs.get("lora_config", None)
    model_class = _get_model_architecture(model_config)

    # Get the (maybe quantized) linear method.
vllm/model_executor/models/__init__.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Type
import torch.nn as nn

from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron

logger = init_logger(__name__)

@@ -30,7 +30,7 @@ _MODELS = {
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-   "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
+   "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case

@@ -38,11 +38,14 @@ _MODELS = {
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
+   "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+   "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
+   "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
}

# Models not supported by ROCm.

@@ -59,6 +62,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "Sliding window attention is not yet supported in ROCm's flash attention",
}

+# Models not supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+

class ModelRegistry:

@@ -75,8 +81,15 @@ class ModelRegistry:
            logger.warning(
                f"Model architecture {model_arch} is partially supported "
                "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+       elif is_neuron():
+           if model_arch not in _NEURON_SUPPORTED_MODELS:
+               raise ValueError(
+                   f"Model architecture {model_arch} is not supported by "
+                   "Neuron for now.")

        module_name, model_cls_name = _MODELS[model_arch]
+       if is_neuron():
+           module_name = _NEURON_SUPPORTED_MODELS[model_arch]
        module = importlib.import_module(
            f"vllm.model_executor.models.{module_name}")
        return getattr(module, model_cls_name, None)