norm / vllm / Commits

[Neuron] Support inference with transformers-neuronx (#2569)

Unverified commit 3b7178cf, authored Feb 28, 2024 by Liangfu Chen, committed by GitHub on Feb 28, 2024. Parent: e46fa5d5.

Showing 18 changed files with 516 additions and 42 deletions (+516, -42):
examples/offline_inference_neuron.py        +33   -0
tests/lora/conftest.py                       +5   -3
vllm/config.py                              +35   -6
vllm/engine/arg_utils.py                     +7   -9
vllm/engine/llm_engine.py                   +18   -3
vllm/lora/layers.py                          +4   -0
vllm/model_executor/__init__.py              +1   -2
vllm/model_executor/layers/sampler.py       +13   -5
vllm/model_executor/model_loader.py          +5   -5
vllm/model_executor/models/__init__.py      +11   -1
vllm/model_executor/models/neuron/llama.py  +79   -0
vllm/model_executor/neuron_model_loader.py  +66   -0
vllm/model_executor/sampling_metadata.py     +2   -2
vllm/model_executor/utils.py                +17   -0
vllm/utils.py                                +8   -0
vllm/worker/cache_engine.py                  +9   -2
vllm/worker/model_runner.py                 +12   -4
vllm/worker/neuron_worker.py               +191   -0
examples/offline_inference_neuron.py (new file, 0 → 100644)

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="openlm-research/open_llama_3b",
    max_num_seqs=8,
    # The max_model_len and block_size arguments are required to be same as
    # max sequence length, when targeting neuron device. Currently, this is a
    # known limitation in continuous batching support in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=128,
    block_size=128,
    # The device can be automatically detected when AWS Neuron SDK is
    # installed. The device argument can be either unspecified for automated
    # detection, or explicitly assigned.
    device="neuron")
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
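As a follow-on usage sketch (it assumes the `llm` object constructed above and standard vLLM SamplingParams semantics): setting the temperature to zero requests greedy decoding from the same Neuron-backed engine.

# Continues the example above; reuses the Neuron-backed `llm` built there.
greedy_params = SamplingParams(temperature=0.0, max_tokens=32)
for output in llm.generate(["The capital of France is"], greedy_params):
    print(f"Greedy: {output.outputs[0].text!r}")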
tests/lora/conftest.py

@@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
     get_model_old = get_model

-    def get_model_patched(model_config, device_config, lora_config=None):
-        return get_model_old(model_config, device_config,
-                             LoRAConfig(max_loras=4, max_lora_rank=8))
+    def get_model_patched(model_config, device_config, **kwargs):
+        return get_model_old(model_config,
+                             device_config,
+                             lora_config=LoRAConfig(max_loras=4,
+                                                    max_lora_rank=8))

     with patch("vllm.worker.model_runner.get_model", get_model_patched):
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
 ...
vllm/config.py

@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

 logger = init_logger(__name__)

@@ -380,13 +380,21 @@ class ParallelConfig:
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, here we assign TP=1 to avoid sharding
+            # within vLLM directly. Transformer-neuronx would take
+            # neuron_tp_degree attribute, and distribute the workload to
+            # multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce

-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray worker is not supported for Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()

@@ -465,8 +473,29 @@ class SchedulerConfig:

 class DeviceConfig:

-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"


 @dataclass
 ...
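To make the combined effect of these config changes concrete, here is a small, self-contained sketch (illustrative only, not vLLM code; the helper names are hypothetical): "auto" device resolution prefers CUDA and falls back to "neuron" when transformers-neuronx is importable, and a requested tensor-parallel degree is handed off as neuron_tp_degree while vLLM itself keeps tensor_parallel_size=1.

# Hypothetical helpers mirroring the DeviceConfig/ParallelConfig logic above.
import torch


def detect_device_type() -> str:
    try:
        import transformers_neuronx  # noqa: F401  (same probe as is_neuron())
        neuron_available = True
    except ImportError:
        neuron_available = False
    if torch.cuda.is_available():
        return "cuda"
    if neuron_available:
        return "neuron"
    raise RuntimeError("No supported device detected.")


def remap_tensor_parallelism(tensor_parallel_size: int, device_type: str) -> dict:
    if device_type == "neuron":
        # vLLM keeps TP=1; transformers-neuronx shards across NeuronCores.
        return {"tensor_parallel_size": 1,
                "neuron_tp_degree": tensor_parallel_size}
    return {"tensor_parallel_size": tensor_parallel_size}


if __name__ == "__main__":
    device_type = detect_device_type()
    print(device_type, remap_tensor_parallelism(2, device_type))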
vllm/engine/arg_utils.py

@@ -44,7 +44,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = 'auto'

     def __post_init__(self):
         if self.tokenizer is None:

@@ -171,7 +171,7 @@ class EngineArgs:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
+                            choices=[8, 16, 32, 128],
                             help='token block size')
         parser.add_argument('--seed',
                             type=int,

@@ -264,13 +264,11 @@ class EngineArgs:
                             help=('Maximum number of LoRAs to store in CPU memory. '
                                   'Must be >= than max_num_seqs. '
                                   'Defaults to max_num_seqs.'))
         parser.add_argument("--device",
                             type=str,
                             default=EngineArgs.device,
-                            choices=["cuda"],
-                            help=('Device type for vLLM execution. '
-                                  'Currently, only CUDA-compatible devices are supported.'))
+                            choices=["auto", "cuda", "neuron"],
+                            help='Device type for vLLM execution.')
         return parser

     @classmethod
 ...
vllm/engine/llm_engine.py

@@ -3,6 +3,7 @@ from collections import defaultdict
 import os
 import time
 import pickle
+import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)

@@ -20,7 +21,8 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
-from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
+from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
+                        get_open_port, get_distributed_init_method)

 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@@ -31,6 +33,12 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.

@@ -138,10 +146,17 @@ class LLMEngine:
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         assert self.parallel_config.world_size == 1, (
             "Ray is required if parallel_config.world_size > 1.")

@@ -243,7 +258,7 @@ class LLMEngine:
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
 ...
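The worker dispatch added above boils down to an importlib lookup keyed by the device type. A minimal, self-contained sketch of that pattern follows (the function name is illustrative; the module paths are the ones registered in DEVICE_TO_WORKER_MODULE_MAP):

# Minimal sketch of the device-type -> worker-class dispatch used by
# LLMEngine._dispatch_worker() above.
import importlib

DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
}


def dispatch_worker_class(device_type: str):
    # Import lazily so that e.g. torch.cuda/xformers are only pulled in for
    # the backend that is actually selected.
    worker_module = DEVICE_TO_WORKER_MODULE_MAP[device_type]
    imported_worker = importlib.import_module(worker_module)
    return imported_worker.Worker


# Example (requires vLLM installed; "neuron" additionally needs the AWS
# Neuron SDK and transformers-neuronx):
# Worker = dispatch_worker_class("neuron")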
vllm/lora/layers.py

@@ -795,6 +795,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         self.dtype = dtype
         self.device = device

+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size
 ...
vllm/model_executor/__init__.py

 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.utils import set_random_seed, get_model

 __all__ = [
     "InputMetadata",
 ...
vllm/model_executor/layers/sampler.py

@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron


 class Sampler(nn.Module):

@@ -32,6 +33,8 @@ class Sampler(nn.Module):
                  org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generate outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size

@@ -55,10 +58,14 @@ class Sampler(nn.Module):
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because

@@ -395,7 +402,8 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):

@@ -407,7 +415,7 @@ def _sample(
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
 ...
vllm/model_executor/model_loader.py

 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Type
+from typing import Type

 import torch
 import torch.nn as nn

-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
 ...

@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")


 def get_model(model_config: ModelConfig,
               device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)

     # Get the (maybe quantized) linear method.
 ...
vllm/model_executor/models/__init__.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Type
 import torch.nn as nn

 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron

 logger = init_logger(__name__)

@@ -61,6 +61,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }

+# Models not supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+

 class ModelRegistry:

@@ -77,8 +80,15 @@ class ModelRegistry:
             logger.warning(
                 f"Model architecture {model_arch} is partially supported "
                 "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        elif is_neuron():
+            if model_arch not in _NEURON_SUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "Neuron for now.")

         module_name, model_cls_name = _MODELS[model_arch]
+        if is_neuron():
+            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)
 ...
vllm/model_executor/models/neuron/llama.py (new file, 0 → 100644)

"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaForCausalLM(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        with torch.inference_mode():
            block_size = self.model.context_buckets[-1]
            if input_metadata.is_prompt:
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                seq_ids = input_metadata.block_tables
            logits = self.model(input_ids,
                                cache_ids=positions,
                                start_ids=seq_ids.flatten())
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        from transformers_neuronx.llama.model import LlamaForSampling

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            from transformers.models.llama import LlamaForCausalLM
            from transformers_neuronx.module import save_pretrained_split

            hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
                                                        low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = LlamaForSampling.from_pretrained(split_model_dir,
                                                      **kwargs)
        self.model.to_neuron()
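load_weights() above resolves a transformers-neuronx "-split" checkpoint directory before compiling the model onto NeuronCores. The resolution order can be sketched on its own (a hypothetical helper, not part of vLLM; the conversion step with save_pretrained_split() is only indicated by a comment):

import os


def resolve_split_model_dir(model_name_or_path: str) -> str:
    # Mirrors the order used in LlamaForCausalLM.load_weights() above.
    split_model_dir = f"{model_name_or_path}-split"
    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
        # The given path already holds a checkpoint usable as-is.
        return model_name_or_path
    if not os.path.exists(split_model_dir):
        # In load_weights() the HuggingFace checkpoint is loaded here and
        # converted with transformers_neuronx.module.save_pretrained_split().
        pass
    return split_model_dir


if __name__ == "__main__":
    print(resolve_split_model_dir("openlm-research/open_llama_3b"))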
vllm/model_executor/neuron_model_loader.py (new file, 0 → 100644)

"""Utilities for selecting and loading models."""
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry

TORCH_DTYPE_TO_NEURON_AMP = {
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> nn.Module:
    from transformers_neuronx.config import (NeuronConfig,
                                             ContinuousBatchingConfig)

    parallel_config = kwargs.get("parallel_config")
    scheduler_config = kwargs.get("scheduler_config")

    model_class = _get_model_architecture(model_config.hf_config)
    linear_method = None

    # Create a model instance.
    model = model_class(model_config.hf_config, linear_method)

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    neuron_config = NeuronConfig(
        continuous_batching=continuous_batching_config)

    # Load the weights from the cached or downloaded files.
    model.load_weights(
        model_config.model,
        model_config.download_dir,
        model_config.load_format,
        model_config.revision,
        tp_degree=parallel_config.neuron_tp_degree,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=neuron_config,
        context_length_estimate=[scheduler_config.max_model_len],
        n_positions=[scheduler_config.max_model_len],
        batch_size=scheduler_config.max_num_seqs)

    return model.eval()
vllm/model_executor/sampling_metadata.py

@@ -5,7 +5,7 @@ import torch
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, is_neuron

 _SAMPLING_EPS = 1e-5

@@ -155,7 +155,7 @@ class SamplingTensors:
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = not in_wsl()
+        pin_memory = not in_wsl() and not is_neuron()
         prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))
 ...
vllm/model_executor/utils.py

 """Utils for model executor."""
 import random
+import importlib
 from typing import Any, Dict, Optional

 import numpy as np
 import torch

+from vllm.config import DeviceConfig, ModelConfig
+
+DEVICE_TO_MODEL_LOADER_MAP = {
+    "cuda": "model_loader",
+    "neuron": "neuron_model_loader",
+}
+

 def set_random_seed(seed: int) -> None:
     random.seed(seed)
 ...

@@ -33,3 +41,12 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
         setattr(weight, key, value)
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> torch.nn.Module:
+    model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
+    imported_model_loader = importlib.import_module(
+        f"vllm.model_executor.{model_loader_module}")
+    get_model_fn = imported_model_loader.get_model
+    return get_model_fn(model_config, device_config, **kwargs)
vllm/utils.py

@@ -118,6 +118,14 @@ def is_hip() -> bool:
     return torch.version.hip is not None


+def is_neuron() -> bool:
+    try:
+        import transformers_neuronx
+    except ImportError:
+        transformers_neuronx = None
+    return transformers_neuronx is not None
+
+
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
     # NOTE: This import statement should be executed lazily since
 ...
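is_neuron() probes for the transformers-neuronx package by attempting an import. A roughly equivalent formulation (an illustrative alternative, not what this commit ships) checks the import machinery without executing the package:

# Illustrative alternative to is_neuron(): probe for an optional dependency
# without importing it. find_spec() returns None when the package is absent.
import importlib.util


def has_package(name: str) -> bool:
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    print(has_package("transformers_neuronx"))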
vllm/worker/cache_engine.py

@@ -3,10 +3,9 @@ from typing import Dict, List, Tuple
 import torch

-from vllm._C import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE

 logger = init_logger(__name__)

@@ -39,6 +38,10 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks

+        # Skip initializing CUDA stream and buffer for Neuron backend.
+        if is_neuron():
+            return
+
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
         else:

@@ -121,6 +124,8 @@ class CacheEngine:
         dst: List[KVCache],
         src_to_dst: Dict[int, int],
     ) -> None:
+        from vllm._C import cache_ops
+
         with torch.cuda.stream(self.cache_stream):
             for i in range(self.num_layers):
                 src_key_cache, src_value_cache = src[i]

@@ -140,6 +145,8 @@ class CacheEngine:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        from vllm._C import cache_ops
+
         key_caches = [key_cache for key_cache, _ in self.gpu_cache]
         value_caches = [value_cache for _, value_cache in self.gpu_cache]
         # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
 ...
vllm/worker/model_runner.py

@@ -80,9 +80,16 @@ class ModelRunner:
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype

+        # Set enforce_eager to True for Neuron backend, to avoid capturing graph
+        if self.device_config.is_neuron:
+            self.model_config.enforce_eager = True
+
     def load_model(self) -> None:
-        self.model = get_model(self.model_config, self.device_config,
-                               self.lora_config)
+        self.model = get_model(self.model_config,
+                               self.device_config,
+                               lora_config=self.lora_config,
+                               parallel_config=self.parallel_config,
+                               scheduler_config=self.scheduler_config)

         vocab_size = self.model.config.vocab_size

@@ -393,6 +400,7 @@ class ModelRunner:
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
+        pin_memory = not self.in_wsl and not self.device_config.is_neuron

         max_subquery_len = max(subquery_lens) if subquery_lens else 1
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):

@@ -443,12 +451,12 @@ class ModelRunner:
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
-                                            pin_memory=not self.in_wsl)
+                                            pin_memory=pin_memory)
         categorized_sample_indices = {
             t: _async_h2d(seq_ids,
                           dtype=torch.int,
                           target_device=self.device,
-                          pin_memory=not self.in_wsl)
+                          pin_memory=pin_memory)
             for t, seq_ids in categorized_sample_indices.items()
         }
 ...
vllm/worker/neuron_worker.py (new file, 0 → 100644)

"""A Neuron worker class."""
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner


class Worker:
    """A worker class that executes the model on a group of neuron cores.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config,
                                        parallel_config,
                                        scheduler_config,
                                        device_config,
                                        lora_config=self.lora_config,
                                        is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self) -> None:
        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method,
                                      distributed_backend="gloo")

        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int = 128,
        gpu_memory_utilization: float = 0.9,
        cpu_swap_space: int = 0,
        cache_dtype: str = "float16",
    ) -> Tuple[int, int]:
        """Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks."""
        num_gpu_blocks = self.scheduler_config.max_num_seqs
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
        # Warm up is maintained in transformers-neuronx
        pass

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        if cache_events is not None:
            raise NotImplementedError(
                "cache operations are not implemented for neuron backend.")

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    distributed_backend: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        distributed_backend = distributed_backend if distributed_backend else "nccl"
        torch.distributed.init_process_group(
            backend=distributed_backend,
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1))
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)