Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d6123170
Unverified
Commit
d6123170
authored
Mar 10, 2025
by
gnovack
Committed by
GitHub
Mar 10, 2025
Browse files
[Neuron] Add Neuron device communicator for vLLM v1 (#14085)
parent
485afdd3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
127 additions
and
0 deletions
+127
-0
tests/neuron/test_comm_ops.py
tests/neuron/test_comm_ops.py
+100
-0
vllm/distributed/device_communicators/neuron_communicator.py
vllm/distributed/device_communicators/neuron_communicator.py
+19
-0
vllm/platforms/neuron.py
vllm/platforms/neuron.py
+8
-0
No files found.
tests/neuron/test_comm_ops.py
0 → 100644
View file @
d6123170
# SPDX-License-Identifier: Apache-2.0
import
functools
from
typing
import
Callable
from
unittest.mock
import
patch
import
pytest
import
torch
import
torch_xla.distributed.xla_multiprocessing
as
xmp
from
typing_extensions
import
ParamSpec
from
vllm.distributed.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.utils
import
get_distributed_init_method
,
get_open_port
_P
=
ParamSpec
(
"_P"
)
def
reinitialize_neuron_runtime
(
f
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
"""Decorator to reinitialize the Neuron Runtime before executing a test.
This is necessary for distributed tests which need to reallocate Neuron
Cores to separate subprocesses.
"""
@
functools
.
wraps
(
f
)
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
None
:
runtime
=
torch
.
classes
.
neuron
.
Runtime
()
runtime
.
initialize
()
runtime
.
unsafe_close
()
f
(
*
args
,
**
kwargs
)
runtime
.
initialize
()
return
wrapper
def
all_gather_test_worker
(
index
,
tp_degree
,
distributed_init_method
):
init_distributed_environment
(
tp_degree
,
index
,
distributed_init_method
,
index
,
backend
=
"xla"
)
ensure_model_parallel_initialized
(
tp_degree
,
1
)
num_dimensions
=
3
tensor_size
=
list
(
range
(
2
,
num_dimensions
+
2
))
total_size
=
1
for
s
in
tensor_size
:
total_size
*=
s
all_gather_dimension
=
-
1
all_tensors
=
[
torch
.
arange
(
total_size
,
dtype
=
torch
.
float32
,
device
=
"xla"
).
reshape
(
tensor_size
)
*
(
r
+
1
)
for
r
in
range
(
tp_degree
)
]
expected
=
torch
.
cat
(
all_tensors
,
dim
=
all_gather_dimension
)
t
=
all_tensors
[
index
%
tp_degree
]
t
=
tensor_model_parallel_all_gather
(
t
,
all_gather_dimension
)
torch
.
testing
.
assert_close
(
t
,
expected
)
def
all_reduce_test_worker
(
index
,
tp_degree
,
distributed_init_method
):
init_distributed_environment
(
tp_degree
,
index
,
distributed_init_method
,
index
,
backend
=
"xla"
)
ensure_model_parallel_initialized
(
tp_degree
,
1
)
num_elements
=
8
all_tensors
=
[
torch
.
arange
(
num_elements
,
dtype
=
torch
.
float32
,
device
=
"xla"
)
*
(
r
+
1
)
for
r
in
range
(
tp_degree
)
]
expected
=
torch
.
sum
(
torch
.
stack
(
all_tensors
,
dim
=
0
),
dim
=
0
)
t
=
all_tensors
[
index
%
tp_degree
]
t
=
tensor_model_parallel_all_reduce
(
t
)
torch
.
testing
.
assert_close
(
t
,
expected
)
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
all_reduce_test_worker
,
all_gather_test_worker
])
@
reinitialize_neuron_runtime
def
test_neuron_multi_process_tensor_parallel
(
monkeypatch
,
tp_size
,
test_target
):
with
patch
(
'torch_xla._XLAC._xla_runtime_is_initialized'
,
return_value
=
False
):
distributed_init_method
=
get_distributed_init_method
(
"127.0.0.1"
,
get_open_port
())
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"NEURONCORE_NUM_DEVICES"
,
str
(
tp_size
))
monkeypatch
.
setenv
(
"NEURON_PJRT_PROCESSES_NUM_DEVICES"
,
','
.
join
([
'1'
for
_
in
range
(
tp_size
)]))
xmp
.
spawn
(
test_target
,
args
=
(
tp_size
,
distributed_init_method
))
vllm/distributed/device_communicators/neuron_communicator.py
0 → 100644
View file @
d6123170
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
)
from
vllm.platforms
import
current_platform
if
current_platform
.
is_neuron
():
import
torch_xla.core.xla_model
as
xm
class
NeuronCommunicator
(
DeviceCommunicatorBase
):
def
all_reduce
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
xm
.
all_reduce
(
xm
.
REDUCE_SUM
,
x
)
def
all_gather
(
self
,
x
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
assert
dim
==
-
1
,
"Neuron only supports dim=-1 for all-gather."
return
xm
.
all_gather
(
x
,
dim
=
dim
)
vllm/platforms/neuron.py
View file @
d6123170
...
...
@@ -2,6 +2,7 @@
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
...
...
@@ -56,6 +57,13 @@ class NeuronPlatform(Platform):
logger
.
warning
(
"Pin memory is not supported on Neuron."
)
return
False
@
classmethod
def
get_device_communicator_cls
(
cls
)
->
str
:
if
envs
.
VLLM_USE_V1
:
return
"vllm.distributed.device_communicators.neuron_communicator.NeuronCommunicator"
# noqa
else
:
return
Platform
.
get_device_communicator_cls
()
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
return
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment