Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec5e299c
Commit
ec5e299c
authored
Feb 21, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.3' into v0.7.3-dev
parents
47bd229c
ed6e9075
Changes
521
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
913 additions
and
221 deletions
+913
-221
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+11
-5
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+7
-4
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+1
-1
vllm/entrypoints/openai/serving_transcription.py
vllm/entrypoints/openai/serving_transcription.py
+306
-0
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+1
-1
vllm/envs.py
vllm/envs.py
+23
-2
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+12
-3
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+4
-0
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+12
-5
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+6
-1
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+100
-14
vllm/inputs/registry.py
vllm/inputs/registry.py
+32
-48
vllm/logits_process.py
vllm/logits_process.py
+1
-1
vllm/lora/layers.py
vllm/lora/layers.py
+6
-4
vllm/lora/models.py
vllm/lora/models.py
+7
-1
vllm/lora/ops/triton_ops/kernel_utils.py
vllm/lora/ops/triton_ops/kernel_utils.py
+243
-0
vllm/lora/ops/triton_ops/sgmv_expand.py
vllm/lora/ops/triton_ops/sgmv_expand.py
+44
-73
vllm/lora/ops/triton_ops/sgmv_shrink.py
vllm/lora/ops/triton_ops/sgmv_shrink.py
+40
-56
vllm/lora/punica_wrapper/punica_base.py
vllm/lora/punica_wrapper/punica_base.py
+1
-1
vllm/lora/punica_wrapper/punica_hpu.py
vllm/lora/punica_wrapper/punica_hpu.py
+56
-1
No files found.
Too many changes to show.
To preserve performance only
521 of 521+
files are displayed.
Plain diff
Email patch
vllm/entrypoints/openai/serving_chat.py
View file @
ec5e299c
...
...
@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
MistralToolCall
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
maybe_serialize_tool_calls
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
logger
=
init_logger
(
__name__
)
...
...
@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
create_error_response
(
"tool_choice =
\"
required
\"
is not supported!"
)
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
...
@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
elif
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
tool_call_class
=
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
message
=
ChatMessage
(
role
=
role
,
content
=
""
,
tool_calls
=
[
T
ool
C
all
(
function
=
FunctionCall
(
t
ool
_c
all
_class
(
function
=
FunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
output
.
text
))
])
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
ec5e299c
...
...
@@ -31,7 +31,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ErrorResponse
,
RerankRequest
,
ScoreRequest
,
TokenizeChatRequest
,
TokenizeCompletionRequest
)
TokenizeCompletionRequest
,
TranscriptionRequest
)
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
# yapf: enable
...
...
@@ -57,7 +58,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
ChatLikeRequest
=
Union
[
ChatCompletionRequest
,
EmbeddingChatRequest
,
TokenizeChatRequest
]
AnyRequest
=
Union
[
CompletionLikeRequest
,
ChatLikeRequest
]
AnyRequest
=
Union
[
CompletionLikeRequest
,
ChatLikeRequest
,
TranscriptionRequest
]
class
TextTokensPrompt
(
TypedDict
):
...
...
@@ -400,8 +402,7 @@ class OpenAIServing:
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
request_prompt
:
Union
[
str
,
List
[
int
]]
is_mistral_tokenizer
=
isinstance
(
tokenizer
,
MistralTokenizer
)
if
is_mistral_tokenizer
:
if
isinstance
(
tokenizer
,
MistralTokenizer
):
request_prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
messages
,
...
...
@@ -450,6 +451,8 @@ class OpenAIServing:
prompt_token_ids
=
prompt_inputs
[
"prompt_token_ids"
])
if
mm_data
is
not
None
:
engine_prompt
[
"multi_modal_data"
]
=
mm_data
if
request
.
mm_processor_kwargs
is
not
None
:
engine_prompt
[
"mm_processor_kwargs"
]
=
request
.
mm_processor_kwargs
return
conversation
,
[
request_prompt
],
[
engine_prompt
]
...
...
vllm/entrypoints/openai/serving_score.py
View file @
ec5e299c
...
...
@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
executor
=
self
.
_tokenizer_executor
)
prompt_inputs
=
await
tokenize_async
(
text
=
q
,
prompt_inputs
=
await
tokenize_async
(
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
...
...
vllm/entrypoints/openai/serving_transcription.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
io
from
typing
import
AsyncGenerator
,
Optional
,
Union
,
cast
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
RequestResponseMetadata
,
TranscriptionRequest
,
TranscriptionResponse
,
TranscriptionResponseVerbose
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.utils
import
PlaceholderModule
try
:
import
librosa
except
ImportError
:
librosa
=
PlaceholderModule
(
"librosa"
)
# type: ignore[assignment]
logger
=
init_logger
(
__name__
)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
# TODO these configs should live somewhere with the model so we can support
# additional ones

# Maps ISO 639-1 language codes to display names. A request whose `language`
# is a key of this table is accepted by the transcription preprocessing
# without any warning.
ISO639_1_SUPPORTED_LANGS = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "be": "Belarusian",
    "bs": "Bosnian",
    "bg": "Bulgarian",
    "ca": "Catalan",
    "zh": "Chinese",
    "hr": "Croatian",
    "cs": "Czech",
    "da": "Danish",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "gl": "Galician",
    "de": "German",
    "el": "Greek",
    "he": "Hebrew",
    "hi": "Hindi",
    "hu": "Hungarian",
    "is": "Icelandic",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "kk": "Kazakh",
    "ko": "Korean",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "mk": "Macedonian",
    "ms": "Malay",
    "mr": "Marathi",
    "mi": "Maori",
    "ne": "Nepali",
    "no": "Norwegian",
    "fa": "Persian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sr": "Serbian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "es": "Spanish",
    "sw": "Swahili",
    "sv": "Swedish",
    "tl": "Tagalog",
    "ta": "Tamil",
    "th": "Thai",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "cy": "Welsh"
}

# Language codes that are accepted, but only after logging a warning about
# limited accuracy (reported WER>=0.5).
# NOTE(review): "mi" is also a key of ISO639_1_SUPPORTED_LANGS; the supported
# table is consulted first, so the "mi" entry below never triggers the warning.
ISO639_1_OTHER_LANGS = {
    "lo": "Lao",
    "jw": "Javanese",
    "tk": "Turkmen",
    "yi": "Yiddish",
    "so": "Somali",
    "bn": "Bengali",
    "nn": "Norwegian Nynorsk",
    "si": "Sinhala",
    "yo": "Yoruba",
    "sa": "Sanskrit",
    "mi": "Māori",
    "fo": "Faroese",  # codespell:ignore
    "mt": "Maltese",
    "tg": "Tajik",
    "mg": "Malagasy",
    "haw": "Hawaiian",
    "km": "Khmer",
    "br": "Breton",
    "ps": "Pashto",
    "ln": "Lingala",
    "la": "Latin",
    "ml": "Malayalam",
    "sq": "Albanian",
    "su": "Sundanese",
    "eu": "Basque",
    "ka": "Georgian",
    "uz": "Uzbek",
    "sn": "Shona",
    "ht": "Haitian",
    "as": "Assamese",
    "mn": "Mongolian",
    "te": "Telugu",
    "pa": "Panjabi",
    "tt": "Tatar",
    "gu": "Gujarati",
    "oc": "Occitan",
    "ha": "Hausa",
    "ba": "Bashkir",
    "my": "Burmese",
    "sd": "Sindhi",
    "am": "Amharic",
    "lb": "Luxembourgish",
    "bo": "Tibetan"
}

# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
# Upper bound on the uploaded audio size, in MiB (compared against
# len(audio_data) / 1024**2).
MAX_AUDIO_CLIP_FILESIZE_MB = 25
# TODO get from processor.feature_extractor.chunk_length
# Upper bound on the decoded clip duration, in seconds.
MAX_AUDIO_CLIP_DURATION_S = 30
class OpenAIServingTranscription(OpenAIServing):
    """Serving handler for an OpenAI-style audio transcription endpoint.

    Validates the request and the uploaded audio, builds an encoder/decoder
    prompt from them, submits it to the engine client and returns the
    generated text.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        """Forward all arguments to ``OpenAIServing`` and log any sampling
        parameter overrides reported by the model config."""
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        # Looked up here only so that a non-empty override gets logged once.
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                diff_sampling_param)

    async def _preprocess_transcription(
        self,
        request: TranscriptionRequest,
        audio_data: bytes,
    ) -> PromptType:
        """Validate the request and turn the raw audio into an engine prompt.

        Args:
            request: Transcription request; ``language`` and ``prompt`` are
                read from it.
            audio_data: Raw bytes of the uploaded audio file.

        Returns:
            A dict with an ``encoder_prompt`` (carrying the decoded audio as
            multi-modal data) and a ``decoder_prompt`` string, cast to
            ``PromptType``.

        Raises:
            ValueError: If the language code is in neither language table,
                the payload is larger than ``MAX_AUDIO_CLIP_FILESIZE_MB`` or
                the clip is longer than ``MAX_AUDIO_CLIP_DURATION_S``.
        """
        # Validate request
        # TODO language should be optional and can be guessed.
        # For now we default to en. See
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
        language = request.language
        lang_token = f"<|{language}|>" if language else "<|en|>"
        if language and language not in ISO639_1_SUPPORTED_LANGS:
            if language in ISO639_1_OTHER_LANGS:
                logger.warning(
                    "The selected language %s has limited accuracy with"
                    " reported WER>=0.5. Results may be less accurate "
                    "for this choice.", language)
            else:
                # The request is matched against ISO 639-1 codes, so list
                # the accepted codes (not the display names) in the error.
                raise ValueError(
                    f"Unsupported language: {language}. "
                    "Language should be one of: "
                    f"{list(ISO639_1_SUPPORTED_LANGS)} "
                    f"or {list(ISO639_1_OTHER_LANGS)}")

        # Size check happens before decoding so oversized uploads are
        # rejected without touching librosa.
        if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
            raise ValueError("Maximum file size exceeded.")

        with io.BytesIO(audio_data) as bytes_:
            y, sr = librosa.load(bytes_)
        if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
            raise ValueError(
                f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
                "exceeded.")

        prompt = {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {
                    "audio": (y, sr),
                },
            },
            "decoder_prompt":
            (f"<|startoftranscript|>{lang_token}"
             f"<|transcribe|><|notimestamps|>{request.prompt}")
        }
        return cast(PromptType, prompt)

    # TODO (varun) : Make verbose response work !
    async def create_transcription(
        self, audio_data: bytes, request: TranscriptionRequest,
        raw_request: Request
    ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
               ErrorResponse]:
        """Transcription API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/audio/createTranscription
        for the API specification. This API mimics the OpenAI transcription API.

        Args:
            audio_data: Raw bytes of the uploaded audio file.
            request: The parsed transcription request.
            raw_request: The underlying FastAPI request; the request
                metadata is attached to its ``state`` when it is truthy.

        Returns:
            A ``TranscriptionResponse`` holding the text of the first output
            of the last result yielded by the engine, or an error response
            when validation fails, the client disconnects or the engine
            yields no result.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        if request.response_format not in ['text', 'json']:
            return self.create_error_response(
                "Currently only support response_format `text` or `json`")

        # TODO cmpl->transcription?
        request_id = f"cmpl-{self._base_request_id(raw_request)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            if lora_request:
                return self.create_error_response(
                    "Currently do not support LoRA for Transcription.")
            if prompt_adapter_request:
                return self.create_error_response(
                    "Currently do not support PromptAdapter for Transcription."
                )

            prompt = await self._preprocess_transcription(
                request=request,
                audio_data=audio_data,
            )

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Every path out of the ``try`` below either binds the generator or
        # returns, so no ``None`` placeholder (and no assert) is needed.
        result_generator: AsyncGenerator[RequestOutput, None]
        try:
            # TODO(rob): subtract len of tokenized prompt.
            default_max_tokens = self.model_config.max_model_len
            default_params = self.model_config.get_diff_sampling_param()
            sampling_params = request.to_sampling_params(
                default_max_tokens, default_params)

            self._log_inputs(
                request_id,
                prompt['decoder_prompt'],  # type: ignore
                params=sampling_params,
                lora_request=None,
                prompt_adapter_request=None)

            result_generator = self.engine_client.generate(
                prompt,
                sampling_params,
                request_id,
            )
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # TODO(rob): figure out a way to pipe streaming in.
        # Non-streaming response.
        final_output: Optional[RequestOutput] = None
        try:
            # Drain the generator; only the last yielded item is used.
            async for op in result_generator:
                final_output = op
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        if final_output is None:
            # The generator finished without yielding anything; report it
            # instead of failing on an unbound local variable.
            return self.create_error_response(
                "No output was generated for the transcription request.")
        return TranscriptionResponse(text=final_output.outputs[0].text)
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
ec5e299c
...
...
@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
@
staticmethod
def
generate_random_id
():
# Mistral Tool Call Ids must be alphanumeric with a
maximum
length of 9.
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return
""
.
join
(
choices
(
ALPHANUMERIC
,
k
=
9
))
...
...
vllm/envs.py
View file @
ec5e299c
...
...
@@ -60,6 +60,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
VLLM_VIDEO_FETCH_TIMEOUT
:
int
=
30
VLLM_AUDIO_FETCH_TIMEOUT
:
int
=
10
VLLM_MM_INPUT_CACHE_SIZE
:
int
=
256
VLLM_TARGET_DEVICE
:
str
=
"cuda"
MAX_JOBS
:
Optional
[
str
]
=
None
NVCC_THREADS
:
Optional
[
str
]
=
None
...
...
@@ -93,6 +94,8 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
bool
=
False
VLLM_RAY_PER_WORKER_GPUS
:
float
=
1.0
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_CUDART_SO_PATH
:
Optional
[
str
]
=
None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
:
bool
=
True
def
get_default_cache_root
():
...
...
@@ -431,15 +434,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
# Timeout for fetching videos when serving multimodal models
# Default is
15
seconds
# Default is
30
seconds
"VLLM_VIDEO_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_VIDEO_FETCH_TIMEOUT"
,
"
15
"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_VIDEO_FETCH_TIMEOUT"
,
"
30
"
)),
# Timeout for fetching audio when serving multimodal models
# Default is 10 seconds
"VLLM_AUDIO_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_AUDIO_FETCH_TIMEOUT"
,
"10"
)),
# Cache size for multimodal feature/input cache for multimodal models
# in unit of number of multimodal data items (e.g. image, video, audio).
# Default is 256 multimodal data items.
"VLLM_MM_INPUT_CACHE_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MM_INPUT_CACHE_SIZE"
,
"256"
)),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
...
...
@@ -608,6 +617,18 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CUDA_MEM_ALIGN_KV_CACHE"
,
"1"
))),
# In some system, find_loaded_library() may not work. So we allow users to
# specify the path through environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_CUDART_SO_PATH"
,
None
),
# Contiguous cache fetching to avoid using costly gather operation on
# Gaudi3. This is only applicable to HPU contiguous cache. If set to true,
# contiguous cache fetch will be used.
"VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CONTIGUOUS_PA"
,
"true"
).
lower
()
in
(
"1"
,
"true"
),
}
# end-env-vars-definition
...
...
vllm/executor/executor_base.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
time
from
abc
import
ABC
,
abstractmethod
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
...
...
@@ -8,11 +9,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
import
vllm.platforms
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
make_async
...
...
@@ -108,8 +109,8 @@ class ExecutorBase(ABC):
"""
# NOTE: This is logged in the executor because there can be >1 workers.
logger
.
info
(
"# %s blocks: %d, # CPU blocks: %d"
,
current_platform
.
d
ispatch_key
,
num_gpu_blocks
,
num_cpu_blocks
)
vllm
.
platforms
.
current_platform
.
d
evice_name
,
num_gpu_blocks
,
num_cpu_blocks
)
max_concurrency
=
(
num_gpu_blocks
*
self
.
cache_config
.
block_size
/
self
.
model_config
.
max_model_len
)
logger
.
info
(
"Maximum concurrency for %s tokens per request: %.2fx"
,
...
...
@@ -200,15 +201,23 @@ class ExecutorBase(ABC):
if
self
.
is_sleeping
:
logger
.
warning
(
"Executor is already sleeping."
)
return
time_before_sleep
=
time
.
perf_counter
()
self
.
collective_rpc
(
"sleep"
,
kwargs
=
dict
(
level
=
level
))
time_after_sleep
=
time
.
perf_counter
()
self
.
is_sleeping
=
True
logger
.
info
(
"It took %.6f seconds to fall asleep."
,
time_after_sleep
-
time_before_sleep
)
def
wake_up
(
self
):
if
not
self
.
is_sleeping
:
logger
.
warning
(
"Executor is not sleeping."
)
return
time_before_wakeup
=
time
.
perf_counter
()
self
.
collective_rpc
(
"wake_up"
)
time_after_wakeup
=
time
.
perf_counter
()
self
.
is_sleeping
=
False
logger
.
info
(
"It took %.6f seconds to wake up."
,
time_after_wakeup
-
time_before_wakeup
)
def
save_sharded_state
(
self
,
...
...
vllm/executor/ray_distributed_executor.py
View file @
ec5e299c
...
...
@@ -101,6 +101,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
self
.
driver_worker
.
execute_method
)
def
shutdown
(
self
)
->
None
:
logger
.
info
(
"Shutting down Ray distributed executor. If you see error log "
"from logging.cc regarding SIGTERM received, please ignore because "
"this is the expected termination process in Ray."
)
if
hasattr
(
self
,
"forward_dag"
)
and
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
...
...
vllm/executor/ray_utils.py
View file @
ec5e299c
...
...
@@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import
msgspec
import
vllm.platforms
from
vllm.config
import
ParallelConfig
from
vllm.executor.msgspec_utils
import
decode_hook
,
encode_hook
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.utils
import
get_ip
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -35,7 +35,7 @@ try:
class
RayWorkerWrapper
(
WorkerWrapperBase
):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
laz
l
iy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
lazi
l
y initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
...
...
@@ -54,10 +54,10 @@ try:
def
get_node_and_gpu_ids
(
self
)
->
Tuple
[
str
,
List
[
int
]]:
node_id
=
ray
.
get_runtime_context
().
get_node_id
()
device_key
=
current_platform
.
ray_device_key
device_key
=
vllm
.
platforms
.
current_platform
.
ray_device_key
if
not
device_key
:
raise
RuntimeError
(
"current platform %s does not support ray."
,
current_platform
.
device_name
)
vllm
.
platforms
.
current_platform
.
device_name
)
gpu_ids
=
ray
.
get_runtime_context
().
get_accelerator_ids
(
)[
device_key
]
return
node_id
,
gpu_ids
...
...
@@ -118,7 +118,14 @@ try:
)
->
"ModelRunnerOutput"
:
self
.
setup_device_if_necessary
()
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
)
if
isinstance
(
scheduler_output
,
tuple
):
scheduler_output
,
intermediate_tensors
=
scheduler_output
else
:
scheduler_output
,
intermediate_tensors
=
scheduler_output
,
None
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
if
isinstance
(
output
,
IntermediateTensors
):
output
=
scheduler_output
,
output
return
output
def
override_env_vars
(
self
,
vars
:
Dict
[
str
,
str
]):
...
...
vllm/executor/uniproc_executor.py
View file @
ec5e299c
...
...
@@ -28,6 +28,11 @@ class UniProcExecutor(ExecutorBase):
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
local_rank
=
0
# set local rank as the device index if specified
device_info
=
self
.
vllm_config
.
device_config
.
device
.
__str__
().
split
(
":"
)
if
len
(
device_info
)
>
1
:
local_rank
=
int
(
device_info
[
1
])
rank
=
0
kwargs
=
dict
(
vllm_config
=
self
.
vllm_config
,
...
...
@@ -101,7 +106,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
# - MASTER_PORT
distributed_init_method
=
"env://"
rank
=
int
(
os
.
environ
[
"RANK"
])
local_rank
=
rank
local_rank
=
int
(
os
.
environ
[
"LOCAL_RANK"
])
is_driver_worker
=
True
kwargs
=
dict
(
vllm_config
=
self
.
vllm_config
,
...
...
vllm/inputs/preprocess.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
asyncio
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Tuple
,
Union
,
cast
from
typing_extensions
import
assert_never
...
...
@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalInputs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
...
...
@@ -254,14 +255,18 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
if
not
self
.
tokenizer
:
tokenizer
=
None
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
)
if
isinstance
(
prompt
,
list
):
prompt
=
tokenizer
.
decode
(
prompt
)
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
...
...
@@ -275,9 +280,15 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
],
)
->
MultiModalInputs
:
"""Async version of :meth:`_process_multimodal`."""
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
if
not
self
.
tokenizer
:
tokenizer
=
None
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
)
...
...
@@ -485,6 +496,51 @@ class InputPreprocessor:
decoder
=
decoder_inputs
,
)
def
_separate_enc_dec_inputs_from_mm_processor_outputs
(
self
,
inputs
:
SingletonInputs
,
decoder_inputs_to_override
:
Optional
[
SingletonInputs
]
=
None
,
)
->
Tuple
[
SingletonInputs
,
SingletonInputs
]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
encoder_inputs
:
SingletonInputs
decoder_inputs
:
SingletonInputs
if
inputs
[
"type"
]
==
"multimodal"
:
# Multimodal data inputs
assert
(
"encoder_prompt"
in
inputs
and
"encoder_prompt_token_ids"
in
inputs
)
inputs
=
cast
(
MultiModalEncDecInputs
,
inputs
)
encoder_inputs
=
token_inputs
(
prompt
=
inputs
[
"encoder_prompt"
],
prompt_token_ids
=
inputs
[
"encoder_prompt_token_ids"
],
)
if
decoder_inputs_to_override
is
not
None
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
decoder_inputs_to_override
.
get
(
"prompt"
,
""
),
prompt_token_ids
=
decoder_inputs_to_override
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
else
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
inputs
[
"prompt"
],
prompt_token_ids
=
inputs
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
elif
inputs
[
"type"
]
==
"token"
:
# Text-only inputs
encoder_inputs
=
token_inputs
(
prompt
=
""
,
prompt_token_ids
=
[])
decoder_inputs
=
decoder_inputs_to_override
or
inputs
else
:
assert_never
(
inputs
)
# type: ignore[arg-type]
return
encoder_inputs
,
decoder_inputs
def
_process_encoder_decoder_prompt
(
self
,
prompt
:
PromptType
,
...
...
@@ -529,7 +585,6 @@ class InputPreprocessor:
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_inputs
=
None
else
:
...
...
@@ -537,13 +592,28 @@ class InputPreprocessor:
decoder_input
,
request_id
=
request_id
,
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
encoder_
inputs
=
self
.
_prompt_to_llm_inputs
(
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
inputs
))
else
:
encoder_inputs
=
inputs
decoder_inputs
=
None
decoder_inputs
=
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
...
...
@@ -573,13 +643,29 @@ class InputPreprocessor:
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
encoder_
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
request_id
=
request_id
,
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
inputs
))
else
:
encoder_inputs
=
inputs
decoder_inputs
=
None
decoder_inputs
=
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
...
...
vllm/inputs/registry.py
View file @
ec5e299c
...
...
@@ -11,8 +11,9 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from
typing_extensions
import
TypeVar
,
assert_never
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
cached_tokenizer_from_config
)
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
resolve_mm_processor_kwargs
)
...
...
@@ -27,19 +28,9 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
P
=
TypeVar
(
"P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
class
HashableDict
(
dict
):
"""
A dictionary that can be hashed by lru_cache.
"""
# NOTE: pythonic dict is not hashable,
# we override on it directly for simplicity
def
__hash__
(
self
)
->
int
:
# type: ignore[override]
return
hash
(
frozenset
(
self
.
items
()))
_T
=
TypeVar
(
"_T"
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
_P
=
TypeVar
(
"_P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -54,9 +45,9 @@ class InputContext:
def
get_hf_config
(
self
,
typ
:
Union
[
type
[
C
],
tuple
[
type
[
C
],
...]]
=
PretrainedConfig
,
typ
:
Union
[
type
[
_
C
],
tuple
[
type
[
_
C
],
...]]
=
PretrainedConfig
,
/
,
)
->
C
:
)
->
_
C
:
"""
Get the HuggingFace configuration
(:class:`transformers.PretrainedConfig`) of the model,
...
...
@@ -94,10 +85,10 @@ class InputContext:
def
get_hf_processor
(
self
,
typ
:
Union
[
type
[
P
],
tuple
[
type
[
P
],
...]]
=
ProcessorMixin
,
typ
:
Union
[
type
[
_
P
],
tuple
[
type
[
_
P
],
...]]
=
ProcessorMixin
,
/
,
**
kwargs
:
object
,
)
->
P
:
)
->
_
P
:
"""
Get the HuggingFace processor
(:class:`transformers.ProcessorMixin`) of the model,
...
...
@@ -106,33 +97,29 @@ class InputContext:
Raises:
TypeError: If the processor is not of the specified type.
"""
return
cached_processor_from_config
(
self
.
model_config
,
processor_cls
=
typ
,
**
kwargs
,
)
def
init_processor
(
self
,
typ
:
type
[
_T
],
/
,
**
kwargs
:
object
,
)
->
_T
:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
base_kwargs
=
self
.
model_config
.
mm_processor_kwargs
if
base_kwargs
is
None
:
base_kwargs
=
{}
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
if
isinstance
(
typ
,
type
):
merged_kwargs
[
"processor_cls"
]
=
typ
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for
key
,
value
in
merged_kwargs
.
items
():
if
isinstance
(
value
,
dict
):
merged_kwargs
[
key
]
=
HashableDict
(
value
)
hf_processor
=
cached_get_processor
(
self
.
model_config
.
model
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
**
merged_kwargs
,
)
if
not
isinstance
(
hf_processor
,
typ
):
raise
TypeError
(
"Invalid type of HuggingFace processor. "
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
hf_processor
)
}
"
)
return
hf_processor
return
typ
(
**
merged_kwargs
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -142,10 +129,10 @@ class InputProcessingContext(InputContext):
def
get_hf_processor
(
self
,
typ
:
Union
[
type
[
P
],
tuple
[
type
[
P
],
...]]
=
ProcessorMixin
,
typ
:
Union
[
type
[
_
P
],
tuple
[
type
[
_
P
],
...]]
=
ProcessorMixin
,
/
,
**
kwargs
:
object
,
)
->
P
:
)
->
_
P
:
return
super
().
get_hf_processor
(
typ
,
tokenizer
=
self
.
tokenizer
,
...
...
@@ -341,16 +328,13 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
vllm.multimodal.utils
import
cached_get_tokenizer
if
mm_registry
.
has_processor
(
model_config
):
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
tokenizer
=
cached_tokenizer_from_config
(
model_config
)
processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_dummy_data
(
seq_len
)
dummy_data
=
profiler
.
get_dummy_data
(
seq_len
,
is_encoder_data
=
is_encoder_data
)
else
:
model_cls
,
_
=
get_model_architecture
(
model_config
)
if
is_encoder_data
:
...
...
vllm/logits_process.py
View file @
ec5e299c
...
...
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# Mistral tokenizers should not add special tokens
prompt_token_ids
=
tokenizer
.
encode
(
promp
t
=
prompt
)
prompt_token_ids
=
tokenizer
.
encode
(
tex
t
=
prompt
)
else
:
prompt_token_ids
=
tokenizer
.
encode
(
text
=
prompt
,
add_special_tokens
=
False
)
...
...
vllm/lora/layers.py
View file @
ec5e299c
...
...
@@ -16,8 +16,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
)
tensor_model_parallel_all_reduce
)
from
vllm.distributed.utils
import
divide
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -1040,10 +1039,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
linear
_method
.
apply
(
lm_head
,
hidden_states
)
logits
=
lm_head
.
quant
_method
.
apply
(
lm_head
,
hidden_states
)
if
embedding_bias
is
not
None
:
logits
+=
embedding_bias
logits
=
tensor_model_parallel_gather
(
logits
)
# Gather logits for TP
logits
=
self
.
base_layer
.
_gather_logits
(
logits
)
if
logits
is
None
:
return
None
...
...
vllm/lora/models.py
View file @
ec5e299c
...
...
@@ -5,7 +5,8 @@ import math
import
os
import
re
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Type
,
Union
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Type
,
Union
)
import
safetensors.torch
import
torch
...
...
@@ -622,12 +623,14 @@ class LoRAModelManager(AdapterModelManager):
def
_create_merged_loras_inplace
(
self
,
lora_model
:
LoRAModel
)
->
None
:
for
module_name
,
new_module_names
in
self
.
packed_modules
.
items
():
replacement_loras
:
List
[
Optional
[
LoRALayerWeights
]]
=
[]
replaced_module
:
Set
[
str
]
=
set
()
has_replacement
=
False
for
r
in
new_module_names
:
lora
=
lora_model
.
get_lora
(
r
)
replacement_loras
.
append
(
lora
)
if
lora
:
has_replacement
=
True
replaced_module
.
add
(
r
)
if
not
has_replacement
:
continue
for
i
in
range
(
len
(
replacement_loras
)):
...
...
@@ -636,6 +639,9 @@ class LoRAModelManager(AdapterModelManager):
replacement_loras
[
i
]
=
None
lora_model
.
loras
[
module_name
]
=
PackedLoRALayerWeights
.
pack
(
replacement_loras
)
# Remove the modules that have been replaced.
for
module
in
replaced_module
:
lora_model
.
loras
.
pop
(
module
,
None
)
def
deactivate_adapter
(
self
,
adapter_id
:
int
)
->
bool
:
return
deactivate_adapter
(
adapter_id
,
self
.
_active_adapters
,
...
...
vllm/lora/ops/triton_ops/kernel_utils.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
"""
Utilities for Punica kernel construction.
"""
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
mm_k
(
a_ptr
,
b_ptr
,
ak_stride
,
bk_stride
,
offset_k
,
K
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
b_dtype
:
tl
.
constexpr
):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
B (k x n), iterate, through the K dimension to compute the partial/complete
matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bk_stride: K dimension stride of the B matrix
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
"""
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)):
if
EVEN_K
:
tiled_a
=
tl
.
load
(
a_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
else
:
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
)
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
offset_k
[:,
None
]
<
K
-
k
*
(
BLOCK_K
*
SPLIT_K
),
other
=
0
)
if
CAST_TYPE
:
tiled_a
=
tiled_a
.
to
(
b_dtype
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
,
)
a_ptr
+=
BLOCK_K
*
SPLIT_K
*
ak_stride
b_ptr
+=
BLOCK_K
*
SPLIT_K
*
bk_stride
return
accumulator
@
triton
.
jit
def
do_expand_kernel
(
pid_n
,
lora_index
,
slice_id
,
input_ptr
,
lora_ptr
,
out_ptr
,
N
,
K
,
M_LEN
,
ram
,
# array identifying the rows of Input ptr to operate on
slice_start_loc
,
# input ptr strides
input_d0_stride
,
input_d1_stride
,
input_d2_stride
,
# lora ptr strides
ls_d0_ptr
,
ls_d1_ptr
,
ls_d2_ptr
,
# out ptr strides
output_d0_stride
,
output_d1_stride
,
# constants
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
SAME_STRIDE
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice,
compute the matrix product and store in the appropriate output location.
Given that this is an expand kernel, we don't perform any split-K reduction
as the K dimension is assumed to be small.
"""
# ls_d*_ptr can be either an integer or a pointer
if
SAME_STRIDE
:
# integer
cur_lora_d0_stride
=
ls_d0_ptr
cur_lora_d1_stride
=
ls_d1_ptr
cur_lora_d2_stride
=
ls_d2_ptr
else
:
# pointer
cur_lora_d0_stride
=
tl
.
load
(
ls_d0_ptr
+
slice_id
)
cur_lora_d1_stride
=
tl
.
load
(
ls_d1_ptr
+
slice_id
)
cur_lora_d2_stride
=
tl
.
load
(
ls_d2_ptr
+
slice_id
)
# Identify the input_ptr and lora_ptr from slice_id.
if
SLICE_NUM
==
1
:
cur_input_ptr
=
input_ptr
cur_lora_ptr
=
lora_ptr
else
:
cur_input_ptr
=
input_ptr
+
slice_id
*
input_d0_stride
cur_lora_ptr
=
tl
.
load
(
lora_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
out_ptr
.
dtype
.
element_ty
))
# Identify the column indices of B to process.
offset_n
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_n
%
N
,
BLOCK_N
),
BLOCK_N
)
# Identify A and B block pointers
offset_k
=
tl
.
arange
(
0
,
BLOCK_K
)
a_ptr
=
(
cur_input_ptr
+
ram
[:,
None
]
*
input_d1_stride
+
offset_k
[
None
,
:]
*
input_d2_stride
,
)
b_ptr
=
(
cur_lora_ptr
+
cur_lora_d0_stride
*
lora_index
+
offset_k
[:,
None
]
*
cur_lora_d2_stride
+
rbn
[
None
,
:]
*
cur_lora_d1_stride
)
# Compute the block matrix product.
SPLIT_K
=
1
accumulator
=
mm_k
(
a_ptr
,
b_ptr
,
input_d2_stride
,
cur_lora_d2_stride
,
offset_k
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
CAST_TYPE
,
cur_lora_ptr
.
dtype
.
element_ty
)
tiled_c
=
accumulator
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
if
SLICE_NUM
==
1
:
cur_slice_start
=
slice_start_loc
else
:
cur_slice_start
=
tl
.
load
(
slice_start_loc
+
slice_id
)
# Identify the C output pointers to store the results of the accumulator.
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
+
cur_slice_start
offset_cm
=
tl
.
arange
(
0
,
BLOCK_M
)
c_ptr
=
(
out_ptr
+
ram
[:,
None
]
*
output_d0_stride
+
offset_cn
[
None
,
:]
*
output_d1_stride
)
c_mask
=
(
offset_cm
[:,
None
]
<
M_LEN
)
&
(
offset_cn
[
None
,
:]
<
(
cur_slice_start
+
N
))
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
tiled_c
+=
tiled_out
tl
.
store
(
c_ptr
,
tiled_c
,
mask
=
c_mask
)
@
triton
.
jit
def
do_shrink_kernel
(
pid_n
,
pid_sk
,
slice_id
,
lora_index
,
input_ptr
,
lora_ptr
,
out_ptr
,
N
,
K
,
M_LEN
,
ram
,
# input strides
input_d0_stride
,
input_d1_stride
,
# lora strides
lora_d0_stride
,
lora_d1_stride
,
lora_d2_stride
,
# output strides
output_d0_stride
,
output_d1_stride
,
output_d2_stride
,
scaling
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
SLICE_NUM
:
tl
.
constexpr
,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice, compute the
matrix product and store in the appropriate output location.
"""
# Identify the lora_ptr from slice_id.
if
SLICE_NUM
==
1
:
# current lora ptr
cur_lora_ptr
=
lora_ptr
else
:
# current lora ptr
cur_lora_ptr
=
tl
.
load
(
lora_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
input_ptr
.
dtype
.
element_ty
))
# Identify the column indices of B to process.
offset_n
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_n
%
N
,
BLOCK_N
),
BLOCK_N
)
# Identify A and B block pointers
offset_k
=
pid_sk
*
BLOCK_K
+
tl
.
arange
(
0
,
BLOCK_K
)
a_ptr
=
(
input_ptr
+
ram
[:,
None
]
*
input_d0_stride
+
offset_k
[
None
,
:]
*
input_d1_stride
)
b_ptr
=
(
cur_lora_ptr
+
lora_d0_stride
*
lora_index
+
rbn
[
None
,
:]
*
lora_d1_stride
+
offset_k
[:,
None
]
*
lora_d2_stride
)
# Compute partial/complete block matrix product.
accumulator
=
mm_k
(
a_ptr
,
b_ptr
,
input_d1_stride
,
lora_d2_stride
,
offset_k
,
K
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
False
,
cur_lora_ptr
.
dtype
.
element_ty
)
# Identify the C output pointers to store the results of the accumulator.
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_cm
=
tl
.
arange
(
0
,
BLOCK_M
)
cur_out_ptr
=
(
out_ptr
if
SLICE_NUM
==
1
else
out_ptr
+
slice_id
*
output_d0_stride
)
c_ptr
=
cur_out_ptr
+
ram
[:,
None
]
*
output_d1_stride
+
offset_cn
[
None
,
:]
*
output_d2_stride
c_mask
=
(
offset_cm
[:,
None
]
<
M_LEN
)
&
(
offset_cn
[
None
,
:]
<
N
)
accumulator
*=
scaling
# handles write-back with reduction-splitting
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
else
:
tl
.
atomic_add
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
vllm/lora/ops/triton_ops/sgmv_expand.py
View file @
ec5e299c
...
...
@@ -14,6 +14,7 @@ import triton.language as tl
from
vllm.utils
import
direct_register_custom_op
from
.kernel_utils
import
do_expand_kernel
from
.utils
import
_get_lora_b_ptr
...
...
@@ -63,86 +64,56 @@ def _sgmv_expand_kernel(
curr_N
=
N
if
SAME_STRIDE
else
tl
.
load
(
output_hs_ptr
+
slice_id
)
pid_m
=
pid
//
cta_n_num
pid_n
=
pid
%
cta_n_num
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
if
pid_m
*
BLOCK_M
>
M
:
if
pid_m
*
BLOCK_M
>
=
M
:
return
if
pid_n
*
BLOCK_N
>
curr_N
:
if
pid_n
*
BLOCK_N
>
=
curr_N
:
return
lora_index
=
tl
.
load
(
lora_indices
+
cur_batch
)
if
lora_index
==
-
1
:
return
cur_seq_start
=
tl
.
load
(
b_seq_start_loc
+
cur_batch
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_n
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_k
=
tl
.
arange
(
0
,
BLOCK_K
)
ram
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_m
%
M
,
BLOCK_M
),
BLOCK_M
)
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_n
%
curr_N
,
BLOCK_N
),
BLOCK_N
)
# ls_d*_ptr can be either an integer or a pointer
if
SAME_STRIDE
:
# integer
cur_lora_d0_stride
=
ls_d0_ptr
cur_lora_d1_stride
=
ls_d1_ptr
cur_lora_d2_stride
=
ls_d2_ptr
else
:
# pointer
cur_lora_d0_stride
=
tl
.
load
(
ls_d0_ptr
+
slice_id
)
cur_lora_d1_stride
=
tl
.
load
(
ls_d1_ptr
+
slice_id
)
cur_lora_d2_stride
=
tl
.
load
(
ls_d2_ptr
+
slice_id
)
if
SLICE_NUM
==
1
:
cur_input_ptr
=
input_ptr
cur_lora_ptr
=
lora_ptr
else
:
cur_input_ptr
=
input_ptr
+
slice_id
*
input_d0_stride
cur_lora_ptr
=
tl
.
load
(
lora_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
out_ptr
.
dtype
.
element_ty
))
a_ptr
=
(
cur_input_ptr
+
cur_seq_start
*
input_d1_stride
+
ram
[:,
None
]
*
input_d1_stride
+
offset_k
[
None
,
:]
*
input_d2_stride
,
)
b_ptr
=
(
cur_lora_ptr
+
cur_lora_d0_stride
*
lora_index
+
offset_k
[:,
None
]
*
cur_lora_d2_stride
+
rbn
[
None
,
:]
*
cur_lora_d1_stride
)
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
tl
.
cdiv
(
K
,
BLOCK_K
)):
if
EVEN_K
:
tiled_a
=
tl
.
load
(
a_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
else
:
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
K
-
k
*
BLOCK_K
,
other
=
0
)
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
offset_k
[:,
None
]
<
K
-
k
*
BLOCK_K
,
other
=
0
)
if
CAST_TYPE
:
tiled_a
=
tiled_a
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
,
)
a_ptr
+=
BLOCK_K
*
input_d2_stride
b_ptr
+=
BLOCK_K
*
cur_lora_d2_stride
tiled_c
=
accumulator
.
to
(
cur_lora_ptr
.
dtype
.
element_ty
)
if
SLICE_NUM
==
1
:
cur_slice_start
=
slice_start_loc
else
:
cur_slice_start
=
tl
.
load
(
slice_start_loc
+
slice_id
)
offset_cm
=
cur_seq_start
+
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
+
cur_slice_start
c_ptr
=
(
out_ptr
+
offset_cm
[:,
None
]
*
output_d0_stride
+
offset_cn
[
None
,
:]
*
output_d1_stride
)
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
(
cur_slice_start
+
curr_N
))
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
tiled_c
+=
tiled_out
tl
.
store
(
c_ptr
,
tiled_c
,
mask
=
c_mask
)
m_offset
=
tl
.
load
(
b_seq_start_loc
+
cur_batch
)
cta_m_len
=
min
(
BLOCK_M
,
M
-
(
pid_m
*
BLOCK_M
))
cta_m_offset
=
m_offset
+
(
pid_m
*
BLOCK_M
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
ram
=
cta_m_offset
+
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_m
%
cta_m_len
,
BLOCK_M
),
BLOCK_M
)
do_expand_kernel
(
pid_n
,
lora_index
,
slice_id
,
input_ptr
,
lora_ptr
,
out_ptr
,
curr_N
,
K
,
cta_m_len
,
ram
,
# array identifying the rows of Input ptr to operate on
slice_start_loc
,
# input ptr strides
input_d0_stride
,
input_d1_stride
,
input_d2_stride
,
# lora ptr strides
ls_d0_ptr
,
ls_d1_ptr
,
ls_d2_ptr
,
# out ptr strides
output_d0_stride
,
output_d1_stride
,
# constants
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
SAME_STRIDE
,
SLICE_NUM
,
EVEN_K
,
CAST_TYPE
,
ADD_INPUTS
,
)
@
torch
.
inference_mode
()
...
...
vllm/lora/ops/triton_ops/sgmv_shrink.py
View file @
ec5e299c
...
...
@@ -14,6 +14,7 @@ import triton.language as tl
from
vllm.utils
import
direct_register_custom_op
from
.kernel_utils
import
do_shrink_kernel
from
.utils
import
_get_lora_a_ptr
...
...
@@ -62,67 +63,50 @@ def _sgmv_shrink_kernel(
pid_sk
=
pid_mix
%
SPLIT_K
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
if
pid_m
*
BLOCK_M
>
M
:
if
pid_m
*
BLOCK_M
>
=
M
:
return
lora_index
=
tl
.
load
(
lora_indices
+
cur_batch
)
if
lora_index
==
-
1
:
return
cur_seq_start
=
tl
.
load
(
b_seq_start_loc
+
cur_batch
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_n
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_k
=
pid_sk
*
BLOCK_K
+
tl
.
arange
(
0
,
BLOCK_K
)
ram
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_m
%
M
,
BLOCK_M
),
BLOCK_M
)
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_n
%
N
,
BLOCK_N
),
BLOCK_N
)
# input ptr
a_ptr
=
(
input_ptr
+
cur_seq_start
*
input_d0_stride
+
ram
[:,
None
]
*
input_d0_stride
+
offset_k
[
None
,
:]
*
input_d1_stride
)
if
SLICE_NUM
==
1
:
# current lora ptr
cur_lora_ptr
=
lora_ptr
else
:
# current lora ptr
cur_lora_ptr
=
tl
.
load
(
lora_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
input_ptr
.
dtype
.
element_ty
))
b_ptr
=
(
cur_lora_ptr
+
lora_d0_stride
*
lora_index
+
rbn
[
None
,
:]
*
lora_d1_stride
+
offset_k
[:,
None
]
*
lora_d2_stride
)
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)):
if
EVEN_K
:
tiled_a
=
tl
.
load
(
a_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
else
:
k_remaining
=
K
-
k
*
(
BLOCK_K
*
SPLIT_K
)
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
k_remaining
,
other
=
0.0
)
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
offset_k
[:,
None
]
<
k_remaining
,
other
=
0.0
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
)
a_ptr
+=
BLOCK_K
*
SPLIT_K
*
input_d1_stride
b_ptr
+=
BLOCK_K
*
SPLIT_K
*
lora_d2_stride
offset_cm
=
cur_seq_start
+
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
cur_out_ptr
=
(
out_ptr
if
SLICE_NUM
==
1
else
out_ptr
+
slice_id
*
output_d0_stride
)
c_ptr
=
cur_out_ptr
+
offset_cm
[:,
None
]
*
output_d1_stride
+
offset_cn
[
None
,
:]
*
output_d2_stride
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
N
)
accumulator
*=
scaling
# handles write-back with reduction-splitting
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
else
:
tl
.
atomic_add
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
m_offset
=
tl
.
load
(
b_seq_start_loc
+
cur_batch
)
cta_m_len
=
min
(
BLOCK_M
,
M
-
(
pid_m
*
BLOCK_M
))
cta_m_offset
=
m_offset
+
(
pid_m
*
BLOCK_M
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
ram
=
cta_m_offset
+
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_m
%
cta_m_len
,
BLOCK_M
),
BLOCK_M
)
do_shrink_kernel
(
pid_n
,
pid_sk
,
slice_id
,
lora_index
,
input_ptr
,
lora_ptr
,
out_ptr
,
N
,
K
,
cta_m_len
,
ram
,
# input strides
input_d0_stride
,
input_d1_stride
,
# lora strides
lora_d0_stride
,
lora_d1_stride
,
lora_d2_stride
,
# output strides
output_d0_stride
,
output_d1_stride
,
output_d2_stride
,
scaling
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
SLICE_NUM
)
@
torch
.
inference_mode
()
...
...
vllm/lora/punica_wrapper/punica_base.py
View file @
ec5e299c
...
...
@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
dtype
=
torch
.
long
,
device
=
device
)
# 5 is the number of indic
i
es tensors.
# 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
5
...
...
vllm/lora/punica_wrapper/punica_hpu.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
,
Union
,
final
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
,
final
import
torch
from
vllm_hpu_extension.ops
import
(
dispatch_bgmv_embedding
,
dispatch_bgmv_linear
)
from
.punica_base
import
PunicaWrapperBase
from
.utils
import
convert_mapping
if
TYPE_CHECKING
:
# avoid circuit import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
@
final
...
...
@@ -19,6 +25,55 @@ class PunicaWrapperHPU(PunicaWrapperBase):
PunicaWrapperBase
.
__init__
(
self
,
3
*
max_num_batched_tokens
,
max_batches
,
device
)
def
_update_base_metadata
(
self
,
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
List
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
):
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_offsets_tensor
,
indices_len
,
)
=
convert_mapping
(
mapping
,
lora_index_to_id
,
max_loras
,
vocab_size
,
extra_vocab_size
,
self
.
device
,
None
)
# Updating each element in `long_lora_offsets` with `lora_offset` slows
# down perf in HPU due to a series of `strided_insert` ops during lazy
# graph accumulation. Hence HPU appends `lora_offset` to a list and
# converts it to a tensor only after it is ready.
if
long_lora_context
:
index_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
long_lora_offsets
:
List
[
int
]
=
[]
for
i
in
range
(
len
(
index_mapping_indices
)):
lora_offset
:
int
=
long_lora_context
.
offsets_by_lora_id
.
get
(
index_mapping_indices
[
i
],
0
)
long_lora_offsets
.
append
(
lora_offset
)
long_lora_offsets_tensor
=
torch
.
tensor
(
long_lora_offsets
,
device
=
self
.
device
,
dtype
=
torch
.
long
)
indices_len
[
-
1
]
=
long_lora_offsets_tensor
.
shape
[
-
1
]
self
.
_token_lora_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
self
.
_sampler_indices
[:
sampler_indices
.
shape
[
0
]].
copy_
(
sampler_indices
)
self
.
_sampler_indices_padded
[:
sampler_indices_padded
.
shape
[
0
]].
copy_
(
sampler_indices_padded
)
self
.
_embeddings_indices
[:
embeddings_indices
.
shape
[
0
],
:
embeddings_indices
.
shape
[
1
]].
copy_
(
embeddings_indices
)
if
long_lora_offsets_tensor
is
not
None
:
self
.
_long_lora_indices
[:
long_lora_offsets_tensor
.
shape
[
0
]].
copy_
(
long_lora_offsets_tensor
)
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
def
add_lora_embedding
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment