TS-MODELS-OPT / training / Video-Generation-Model

Commit c07946d8, authored Apr 09, 2026 by hepj:

    dit & video

Showing 20 of the commit's 270 changed files, with 2862 additions and 0 deletions (+2862 / -0).
Changed files in this view (additions / deletions):

- FastVideo-main/fastvideo/v1/configs/sample/hunyuan.py (+20 / -0)
- FastVideo-main/fastvideo/v1/configs/sample/registry.py (+75 / -0)
- FastVideo-main/fastvideo/v1/configs/sample/wan.py (+24 / -0)
- FastVideo-main/fastvideo/v1/default_configs/v1_inference_hunyuan_config.yaml (+17 / -0)
- FastVideo-main/fastvideo/v1/distributed/__init__.py (+20 / -0)
- FastVideo-main/fastvideo/v1/distributed/communication_op.py (+32 / -0)
- FastVideo-main/fastvideo/v1/distributed/device_communicators/__init__.py (+0 / -0)
- FastVideo-main/fastvideo/v1/distributed/device_communicators/base_device_communicator.py (+194 / -0)
- FastVideo-main/fastvideo/v1/distributed/device_communicators/cuda_communicator.py (+76 / -0)
- FastVideo-main/fastvideo/v1/distributed/device_communicators/pynccl.py (+218 / -0)
- FastVideo-main/fastvideo/v1/distributed/device_communicators/pynccl_wrapper.py (+341 / -0)
- FastVideo-main/fastvideo/v1/distributed/parallel_state.py (+1183 / -0)
- FastVideo-main/fastvideo/v1/distributed/utils.py (+191 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/__init__.py (+0 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/cli/__init__.py (+0 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/cli/cli_types.py (+26 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/cli/generate.py (+87 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/cli/main.py (+39 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/cli/utils.py (+60 / -0)
- FastVideo-main/fastvideo/v1/entrypoints/video_generator.py (+259 / -0)
FastVideo-main/fastvideo/v1/configs/sample/hunyuan.py (new file, mode 100644)

```python
from dataclasses import dataclass

from fastvideo.v1.configs.sample.base import SamplingParam


@dataclass
class HunyuanSamplingParam(SamplingParam):
    num_inference_steps: int = 50
    num_frames: int = 125
    height: int = 720
    width: int = 1280
    fps: int = 24
    guidance_scale: float = 1.0


@dataclass
class FastHunyuanSamplingParam(HunyuanSamplingParam):
    num_inference_steps: int = 6
```
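Since both presets are plain dataclasses, specializing one is just a field override at construction time. A minimal usage sketch, assuming the `SamplingParam` base class (not shown in this diff) carries defaults for all of its fields:

```python
# Usage sketch; assumes every SamplingParam field has a default value.
from dataclasses import replace

from fastvideo.v1.configs.sample.hunyuan import (FastHunyuanSamplingParam,
                                                 HunyuanSamplingParam)

base = HunyuanSamplingParam()       # 50 steps, 125 frames, 720x1280, cfg 1.0
fast = FastHunyuanSamplingParam()   # identical, except 6 inference steps
preview = replace(base, height=360, width=640)  # one-off override, new object
```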
FastVideo-main/fastvideo/v1/configs/sample/registry.py (new file, mode 100644)

```python
import os
from typing import Any, Callable, Dict, Optional

from fastvideo.v1.configs.sample.hunyuan import (FastHunyuanSamplingParam,
                                                 HunyuanSamplingParam)
from fastvideo.v1.configs.sample.wan import (WanI2V480PSamplingParam,
                                             WanT2V480PSamplingParam)
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import (maybe_download_model_index,
                                verify_model_config_and_directory)

logger = init_logger(__name__)

# Registry maps specific model weights to their config classes
SAMPLING_PARAM_REGISTRY: Dict[str, Any] = {
    "FastVideo/FastHunyuan-diffusers": FastHunyuanSamplingParam,
    "hunyuanvideo-community/HunyuanVideo": HunyuanSamplingParam,
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PSamplingParam,
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PSamplingParam,
    # Add other specific weight variants
}

# For determining pipeline type from model ID
SAMPLING_PARAM_DETECTOR: Dict[str, Callable[[str], bool]] = {
    "hunyuan": lambda id: "hunyuan" in id.lower(),
    "wanpipeline": lambda id: "wanpipeline" in id.lower(),
    "wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(),
    # Add other pipeline architecture detectors
}

# Fallback configs when exact match isn't found but architecture is detected
SAMPLING_FALLBACK_PARAM: Dict[str, Any] = {
    # Base Hunyuan config as fallback for any Hunyuan variant
    "hunyuan": HunyuanSamplingParam,
    # Base Wan config as fallback for any Wan variant
    "wanpipeline": WanT2V480PSamplingParam,
    "wanimagetovideo": WanI2V480PSamplingParam,
    # Other fallbacks by architecture
}


def get_sampling_param_cls_for_name(
        pipeline_name_or_path: str) -> Optional[Any]:
    """Get the appropriate sampling param for specific pretrained weights."""
    if os.path.exists(pipeline_name_or_path):
        config = verify_model_config_and_directory(pipeline_name_or_path)
        logger.warning(
            "FastVideo may not correctly identify the optimal sampling param "
            "for this model, as the local directory may have been renamed.")
    else:
        config = maybe_download_model_index(pipeline_name_or_path)

    pipeline_name = config["_class_name"]

    # First try exact match for specific weights
    if pipeline_name_or_path in SAMPLING_PARAM_REGISTRY:
        return SAMPLING_PARAM_REGISTRY[pipeline_name_or_path]

    # Try partial matches (for local paths that might include the weight ID)
    for registered_id, config_class in SAMPLING_PARAM_REGISTRY.items():
        if registered_id in pipeline_name_or_path:
            return config_class

    # If no match, try to use the fallback config
    fallback_config = None

    # Try to determine pipeline architecture for fallback
    for pipeline_type, detector in SAMPLING_PARAM_DETECTOR.items():
        if detector(pipeline_name.lower()):
            fallback_config = SAMPLING_FALLBACK_PARAM.get(pipeline_type)
            break

    logger.warning(
        "No match found for pipeline %s, using fallback sampling param %s.",
        pipeline_name_or_path, fallback_config)
    return fallback_config
```
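To make the lookup order concrete, a hedged sketch of the three tiers the function walks through. The inputs for tiers 2 and 3 are hypothetical names, and every call first fetches the model index, so the sketch assumes the referenced weights (or a local model index) are actually reachable:

```python
# Illustrative only: the tier-2 and tier-3 inputs below are made-up names.
from fastvideo.v1.configs.sample.registry import get_sampling_param_cls_for_name

# Tier 1: the ID is an exact key in SAMPLING_PARAM_REGISTRY.
get_sampling_param_cls_for_name("FastVideo/FastHunyuan-diffusers")
# -> FastHunyuanSamplingParam

# Tier 2: a local checkout whose path embeds a registered ID (substring match).
get_sampling_param_cls_for_name("/ckpts/FastVideo/FastHunyuan-diffusers")
# -> FastHunyuanSamplingParam

# Tier 3: unknown weights whose model index reports a Hunyuan pipeline class;
# the "hunyuan" detector fires on _class_name and the fallback is returned.
get_sampling_param_cls_for_name("my-org/hunyuan-finetune")
# -> HunyuanSamplingParam
```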
FastVideo-main/fastvideo/v1/configs/sample/wan.py (new file, mode 100644)

```python
from dataclasses import dataclass

from fastvideo.v1.configs.sample.base import SamplingParam


@dataclass
class WanT2V480PSamplingParam(SamplingParam):
    # Video parameters
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 3.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50


@dataclass
class WanI2V480PSamplingParam(WanT2V480PSamplingParam):
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40
```
FastVideo-main/fastvideo/v1/default_configs/v1_inference_hunyuan_config.yaml (new file, mode 100644)

```yaml
num_gpus: 4
model_path: FastVideo/FastHunyuan-diffusers
master_port: 29503
sp_size: 4
tp_size: 4
height: 720
width: 1280
num_frames: 125
num_inference_steps: 6
guidance_scale: 1
embedded_cfg_scale: 6
flow_shift: 17
prompt_path: ./assets/prompt.txt
seed: 1024
output_path: outputs_video/
vae-sp: True
```

(No newline at end of file.)
FastVideo-main/fastvideo/v1/distributed/__init__.py (new file, mode 100644)

```python
# SPDX-License-Identifier: Apache-2.0
from fastvideo.v1.distributed.communication_op import *
from fastvideo.v1.distributed.parallel_state import (
    cleanup_dist_env_and_memory, get_sequence_model_parallel_rank,
    get_sequence_model_parallel_world_size, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size, get_world_group,
    init_distributed_environment, initialize_model_parallel)
from fastvideo.v1.distributed.utils import *

__all__ = [
    "init_distributed_environment",
    "initialize_model_parallel",
    "get_sequence_model_parallel_rank",
    "get_sequence_model_parallel_world_size",
    "get_tensor_model_parallel_rank",
    "get_tensor_model_parallel_world_size",
    "cleanup_dist_env_and_memory",
    "get_world_group",
]
```
FastVideo-main/fastvideo/v1/distributed/communication_op.py (new file, mode 100644)

```python
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py
import torch
import torch.distributed

from fastvideo.v1.distributed.parallel_state import get_sp_group, get_tp_group


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_)


def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)


# TODO: remove model, make it sequence_parallel
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
                                          scatter_dim: int = 2,
                                          gather_dim: int = 1) -> torch.Tensor:
    """All-to-all communication of 4D tensors (e.g. QKV matrices) across
    sequence parallel group."""
    return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)


def sequence_model_parallel_all_gather(input_: torch.Tensor,
                                       dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across sequence parallel group."""
    return get_sp_group().all_gather(input_, dim)
```
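These thin wrappers are where sequence parallelism plugs into attention: tensors are sharded by sequence outside attention and by heads inside it. A shape-only sketch of that placement; the tensors and the surrounding layer are hypothetical, with P standing for the sequence-parallel world size:

```python
# Hypothetical placement inside a sequence-parallel attention layer.
# q and attn_out are placeholders; P = sequence-parallel world size.
from fastvideo.v1.distributed.communication_op import (
    sequence_model_parallel_all_to_all_4D)

# entering attention, q is sequence-sharded: (bs, seqlen/P, num_heads, head_dim)
q = sequence_model_parallel_all_to_all_4D(q, scatter_dim=2, gather_dim=1)
# -> (bs, seqlen, num_heads/P, head_dim): each rank sees the full sequence
#    for its slice of heads, so attention itself runs unmodified.

# ... attention over the full sequence ...

out = sequence_model_parallel_all_to_all_4D(attn_out, scatter_dim=1, gather_dim=2)
# -> (bs, seqlen/P, num_heads, head_dim): back to sequence sharding
```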
FastVideo-main/fastvideo/v1/distributed/device_communicators/__init__.py (new file, mode 100644; empty)
FastVideo-main/fastvideo/v1/distributed/device_communicators/base_device_communicator.py (new file, mode 100644)

```python
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:
    """
    Base class for device-specific communicator.
    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        self.ranks = dist.get_process_group_ranks(cpu_group)
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                 self.global_rank)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def all_to_all_4D(self,
                      input_: torch.Tensor,
                      scatter_dim: int = 2,
                      gather_dim: int = 1) -> torch.Tensor:
        """Specialized all-to-all operation for 4D tensors (e.g., for QKV matrices).

        Args:
            input_ (torch.Tensor): 4D input tensor to be scattered and gathered.
            scatter_dim (int, optional): Dimension along which to scatter. Defaults to 2.
            gather_dim (int, optional): Dimension along which to gather. Defaults to 1.

        Returns:
            torch.Tensor: Output tensor after all-to-all operation.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        assert input_.dim() == 4, (
            f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}")

        if scatter_dim == 2 and gather_dim == 1:
            # input: (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
            bs, shard_seqlen, hc, hs = input_.shape
            seqlen = shard_seqlen * self.world_size
            shard_hc = hc // self.world_size

            # Reshape and transpose for scattering
            input_t = (input_.reshape(bs, shard_seqlen, self.world_size,
                                      shard_hc,
                                      hs).transpose(0, 2).contiguous())

            output = torch.empty_like(input_t)
            torch.distributed.all_to_all_single(output,
                                                input_t,
                                                group=self.device_group)
            torch.cuda.synchronize()

            # Reshape and transpose back
            output = output.reshape(seqlen, bs, shard_hc,
                                    hs).transpose(0, 1).contiguous().reshape(
                                        bs, seqlen, shard_hc, hs)
            return output
        elif scatter_dim == 1 and gather_dim == 2:
            # input: (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
            bs, seqlen, shard_hc, hs = input_.shape
            hc = shard_hc * self.world_size
            shard_seqlen = seqlen // self.world_size

            # Reshape and transpose for scattering
            input_t = (input_.reshape(bs, self.world_size, shard_seqlen,
                                      shard_hc, hs).transpose(0, 3).transpose(
                                          0, 1).contiguous().reshape(
                                              self.world_size, shard_hc,
                                              shard_seqlen, bs, hs))

            output = torch.empty_like(input_t)
            torch.distributed.all_to_all_single(output,
                                                input_t,
                                                group=self.device_group)
            torch.cuda.synchronize()

            # Reshape and transpose back
            output = output.reshape(hc, shard_seqlen, bs,
                                    hs).transpose(0, 2).contiguous().reshape(
                                        bs, shard_seqlen, hc, hs)
            return output
        else:
            raise RuntimeError(
                "scatter_dim must be 1 or 2 and gather_dim must be 1 or 2")

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.

        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        tensor = torch.empty(size, dtype=dtype, device=self.device)
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self) -> None:
        pass
```
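The index gymnastics in `all_to_all_4D` are easy to misread, so here is a self-contained, single-process check of the scatter_dim=2 / gather_dim=1 path: the P ranks are simulated with plain tensors, and `all_to_all_single` is emulated by re-chunking along dim 0 (output chunk j on rank r is input chunk r from rank j). It verifies that each simulated rank ends up with the full sequence for its own head slice:

```python
# Single-process simulation of all_to_all_4D (scatter_dim=2, gather_dim=1).
import torch

P, bs, seqlen, hc, hs = 4, 2, 8, 8, 16
shard_seqlen, shard_hc = seqlen // P, hc // P

full = torch.arange(bs * seqlen * hc * hs,
                    dtype=torch.float32).reshape(bs, seqlen, hc, hs)
# Rank r holds its sequence shard: (bs, seqlen/P, hc, hs)
inputs = [full[:, r * shard_seqlen:(r + 1) * shard_seqlen] for r in range(P)]

# Local pre-all-to-all transform on each rank: (P, seqlen/P, bs, hc/P, hs)
staged = [
    x.reshape(bs, shard_seqlen, P, shard_hc, hs).transpose(0, 2).contiguous()
    for x in inputs
]

# Emulate all_to_all_single, then apply the post-transform from the method.
outputs = []
for r in range(P):
    recv = torch.stack([staged[j][r] for j in range(P)])
    out = recv.reshape(seqlen, bs, shard_hc, hs).transpose(0, 1).contiguous()
    outputs.append(out.reshape(bs, seqlen, shard_hc, hs))

# Each rank now holds the full sequence for its slice of heads.
for r in range(P):
    expected = full[:, :, r * shard_hc:(r + 1) * shard_hc]
    assert torch.equal(outputs[r], expected), r
print("all_to_all_4D shape math checks out")
```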
FastVideo-main/fastvideo/v1/distributed/device_communicators/cuda_communicator.py (new file, mode 100644)

```python
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py
from typing import Optional

import torch
from torch.distributed import ProcessGroup

from fastvideo.v1.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


class CudaCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)

        from fastvideo.v1.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

    def all_reduce(self, input_):
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        out = pynccl_comm.all_reduce(input_)
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.

        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self) -> None:
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
```
FastVideo-main/fastvideo/v1/distributed/device_communicators/pynccl.py (new file, mode 100644)

```python
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl.py
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from fastvideo.v1.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
from fastvideo.v1.distributed.utils import StatelessProcessGroup
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import current_stream

logger = init_logger(__name__)


class PyNcclCommunicator:

    def __init__(
        self,
        group: Union[ProcessGroup, StatelessProcessGroup],
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        if not isinstance(group, StatelessProcessGroup):
            assert dist.is_initialized()
            assert dist.get_backend(group) != dist.Backend.NCCL, (
                "PyNcclCommunicator should be attached to a non-NCCL group.")
            # note: this rank is the rank in the group
            self.rank = dist.get_rank(group)
            self.world_size = dist.get_world_size(group)
        else:
            self.rank = group.rank
            self.world_size = group.world_size

        self.group = group

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return

        self.available = True
        self.disabled = False

        logger.info("FastVideo is using nccl==%s", self.nccl.ncclGetVersion())

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        if not isinstance(group, StatelessProcessGroup):
            tensor = torch.ByteTensor(list(self.unique_id.internal))
            ranks = dist.get_process_group_ranks(group)
            # arg `src` in `broadcast` is the global rank
            dist.broadcast(tensor, src=ranks[0], group=group)
            byte_list = tensor.tolist()
            for i, byte in enumerate(byte_list):
                self.unique_id.internal[i] = byte
        else:
            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)

            stream = current_stream()
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            stream.synchronize()
            del data

    def all_reduce(self,
                   in_tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None) -> torch.Tensor:
        if self.disabled:
            return None
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert in_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {in_tensor.device}")

        out_tensor = torch.empty_like(in_tensor)

        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                buffer_type(out_tensor.data_ptr()),
                                in_tensor.numel(),
                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))
        return out_tensor

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()
        if src == self.rank:
            sendbuff = buffer_type(tensor.data_ptr())
            # NCCL requires the sender also to have a receive buffer
            recvbuff = buffer_type(tensor.data_ptr())
        else:
            sendbuff = buffer_type()
            recvbuff = buffer_type(tensor.data_ptr())
        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                src, self.comm,
                                cudaStream_t(stream.cuda_stream))
```
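A hedged wiring sketch for this class: it assumes a CUDA machine launched with torchrun (which sets RANK, WORLD_SIZE, and LOCAL_RANK), and it follows the constraint asserted above that the wrapped group must not be NCCL-backed, hence the gloo group:

```python
# Hypothetical wiring; requires GPUs and a torchrun-style launcher.
import os

import torch
import torch.distributed as dist

from fastvideo.v1.distributed.device_communicators.pynccl import (
    PyNcclCommunicator)

dist.init_process_group(backend="gloo")  # non-NCCL group, as required
local_rank = int(os.environ["LOCAL_RANK"])
comm = PyNcclCommunicator(group=dist.group.WORLD, device=local_rank)

x = torch.ones(4, device=f"cuda:{local_rank}")
y = comm.all_reduce(x)  # out-of-place: x is untouched, y holds the sum
```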
FastVideo-main/fastvideo/v1/distributed/device_communicators/pynccl_wrapper.py (new file, mode 100644)

```python
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl_wrapper.py

# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
#   1. We tried to use `cupy`. It calls NCCL correctly, but `cupy` itself
#      often gets stuck when initializing the NCCL communicator.
#   2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
#      contains many other potential cuda APIs that are not allowed while
#      capturing a CUDA graph. For further details, please check
#      https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this: it requires
# recompilation every time we want to switch between versions. This
# implementation, with a **pure** Python wrapper, is more flexible. We can
# easily switch between different versions of NCCL by changing the
# environment variable `FASTVIDEO_NCCL_SO_PATH`, or the `so_file` variable
# in the code.

# TODO(will): support FASTVIDEO_NCCL_SO_PATH

import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from torch.distributed import ReduceOp

from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import find_nccl_library

logger = init_logger(__name__)

# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p


class ncclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]


cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.int32:
            return cls.ncclInt32
        if dtype == torch.int64:
            return cls.ncclInt64
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")


ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        if op == ReduceOp.SUM:
            return cls.ncclSum
        if op == ReduceOp.PRODUCT:
            return cls.ncclProd
        if op == ReduceOp.MAX:
            return cls.ncclMax
        if op == ReduceOp.MIN:
            return cls.ncclMin
        if op == ReduceOp.AVG:
            return cls.ncclAvg
        raise ValueError(f"Unsupported op: {op}")


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


class NCCLLibrary:
    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t,
                 [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t ncclCommInitRank(
        #     ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("ncclCommInitRank", ncclResult_t, [
            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
            ctypes.c_int
        ]),
        # ncclResult_t ncclAllReduce(
        #     const void* sendbuff, void* recvbuff, size_t count,
        #     ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #     cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllReduce", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclAllGather(
        #     const void* sendbuff, void* recvbuff, size_t count,
        #     ncclDataType_t datatype, ncclComm_t comm,
        #     cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllGather", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclReduceScatter(
        #     const void* sendbuff, void* recvbuff, size_t count,
        #     ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #     cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclReduceScatter", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclSend(
        #     const void* sendbuff, size_t count, ncclDataType_t datatype,
        #     int dest, ncclComm_t comm, cudaStream_t stream);
        Function("ncclSend", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclRecv(
        #     void* recvbuff, size_t count, ncclDataType_t datatype,
        #     int src, ncclComm_t comm, cudaStream_t stream);
        Function("ncclRecv", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclBroadcast(
        #     const void* sendbuff, void* recvbuff, size_t count,
        #     ncclDataType_t datatype, int root, ncclComm_t comm,
        #     cudaStream_t stream);
        Function("ncclBroadcast", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ctypes.c_int, ncclComm_t, cudaStream_t
        ]),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        so_file = so_file or find_nccl_library()

        try:
            if so_file not in NCCLLibrary.path_to_dict_mapping:
                lib = ctypes.CDLL(so_file)
                NCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = NCCLLibrary.path_to_library_cache[so_file]
        except Exception as e:
            logger.error(
                "Failed to load NCCL library from %s. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted, "
                "or not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable FASTVIDEO_NCCL_SO_PATH "
                "to point to the correct nccl library path.", so_file,
                platform.platform())
            raise e

        if so_file not in NCCLLibrary.path_to_dict_mapping:
            _funcs: Dict[str, Any] = {}
            for func in NCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        return str(self._funcs["ncclGetErrorString"](result).decode("utf-8"))

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    def ncclGetVersion(self) -> str:
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        version_str = str(version.value)
        # something like 21903 --> "2.19.3"
        major = version_str[0].lstrip("0")
        minor = version_str[1:3].lstrip("0")
        patch = version_str[3:].lstrip("0")
        return f"{major}.{minor}.{patch}"

    def ncclGetUniqueId(self) -> ncclUniqueId:
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        comm = ncclComm_t()
        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                        world_size, unique_id,
                                                        rank))
        return comm

    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                          count: int, datatype: int, op: int,
                          comm: ncclComm_t, stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
                                                         count, datatype, op,
                                                         comm, stream))

    def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # which is an alias of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
                                                     datatype, comm, stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                                dest, comm, stream))

    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                comm, stream))

    def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, root: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
                                                     datatype, root, comm,
                                                     stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))


__all__ = [
    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
    "ncclComm_t", "cudaStream_t", "buffer_type"
]
```
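One non-obvious detail worth pinning down is the version parsing in `ncclGetVersion`: NCCL reports an integer such as 21903, which the wrapper slices positionally into major, minor, and patch. A standalone check of that arithmetic:

```python
# Mirrors the slicing in NCCLLibrary.ncclGetVersion.
def split_nccl_version(value: int) -> str:
    s = str(value)
    return f"{s[0].lstrip('0')}.{s[1:3].lstrip('0')}.{s[3:].lstrip('0')}"

assert split_nccl_version(21903) == "2.19.3"
assert split_nccl_version(22004) == "2.20.4"
```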
FastVideo-main/fastvideo/v1/distributed/parallel_state.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Adapted from
"""FastVideo distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.
If you only need to use the distributed environment without model parallelism,
you can skip the model parallel initialization and destruction steps.
"""
import
contextlib
import
gc
import
pickle
import
weakref
from
collections
import
namedtuple
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
import
torch
import
torch.distributed
from
torch.distributed
import
Backend
,
ProcessGroup
import
fastvideo.v1.envs
as
envs
from
fastvideo.v1.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
)
from
fastvideo.v1.distributed.device_communicators.cuda_communicator
import
(
CudaCommunicator
)
from
fastvideo.v1.distributed.utils
import
StatelessProcessGroup
from
fastvideo.v1.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
class
GraphCaptureContext
:
stream
:
torch
.
cuda
.
Stream
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"device"
,
"dtype"
,
"size"
])
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list
:
List
[
Tuple
[
str
,
Any
]]
=
[]
tensor_list
:
List
[
torch
.
Tensor
]
=
[]
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device
=
value
.
device
.
type
metadata_list
.
append
(
(
key
,
TensorMetadata
(
device
,
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
else
:
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
_group_name_counter
:
Dict
[
str
,
int
]
=
{}
def
_get_unique_name
(
name
:
str
)
->
str
:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if
name
not
in
_group_name_counter
:
_group_name_counter
[
name
]
=
0
newname
=
f
"
{
name
}
:
{
_group_name_counter
[
name
]
}
"
_group_name_counter
[
name
]
+=
1
return
newname
_groups
:
Dict
[
str
,
Callable
[[],
Optional
[
"GroupCoordinator"
]]]
=
{}
def
_register_group
(
group
:
"GroupCoordinator"
)
->
None
:
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
def
all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce_out_place
(
tensor
)
def
all_reduce_fake
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
class
GroupCoordinator
:
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It manages both CPU and device
communication.
"""
# available attributes:
rank
:
int
# global rank
ranks
:
List
[
int
]
# global ranks in the group
world_size
:
int
# size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
# Process | Node | Rank | Local Rank | Rank in Group
# 0 | 0 | 0 | 0 | 0
# 1 | 0 | 1 | 1 | 1
# 2 | 1 | 2 | 0 | 2
# 3 | 1 | 3 | 1 | 3
local_rank
:
int
# local rank used to assign devices
rank_in_group
:
int
# rank inside the group
cpu_group
:
ProcessGroup
# group for CPU communication
device_group
:
ProcessGroup
# group for device communication
use_device_communicator
:
bool
# whether to use device communicator
device_communicator
:
DeviceCommunicatorBase
# device communicator
mq_broadcaster
:
Optional
[
Any
]
# shared memory broadcaster
def
__init__
(
self
,
group_ranks
:
List
[
List
[
int
]],
local_rank
:
int
,
torch_distributed_backend
:
Union
[
str
,
Backend
],
use_device_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
group_name
=
group_name
or
"anonymous"
self
.
unique_name
=
_get_unique_name
(
group_name
)
_register_group
(
self
)
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
local_rank
=
local_rank
self
.
device_group
=
None
self
.
cpu_group
=
None
for
ranks
in
group_ranks
:
device_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
if
self
.
rank
in
ranks
:
self
.
ranks
=
ranks
self
.
world_size
=
len
(
ranks
)
self
.
rank_in_group
=
ranks
.
index
(
self
.
rank
)
self
.
device_group
=
device_group
self
.
cpu_group
=
cpu_group
assert
self
.
cpu_group
is
not
None
assert
self
.
device_group
is
not
None
from
fastvideo.v1.platforms
import
current_platform
# TODO: fix it for other platforms
if
current_platform
.
is_cuda_alike
():
self
.
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
else
:
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
use_device_communicator
=
use_device_communicator
self
.
device_communicator
:
DeviceCommunicatorBase
=
None
# type: ignore
if
use_device_communicator
and
self
.
world_size
>
1
:
# device_comm_cls = resolve_obj_by_qualname(
# current_platform.get_device_communicator_cls())
self
.
device_communicator
=
CudaCommunicator
(
cpu_group
=
self
.
cpu_group
,
device
=
self
.
device
,
device_group
=
self
.
device_group
,
unique_name
=
self
.
unique_name
,
)
self
.
mq_broadcaster
=
None
from
fastvideo.v1.platforms
import
current_platform
# TODO(will): check if this is needed
# self.use_custom_op_call = current_platform.is_cuda_alike()
self
.
use_custom_op_call
=
False
@
property
def
first_rank
(
self
):
"""Return the global rank of the first process in the group"""
return
self
.
ranks
[
0
]
@
property
def
last_rank
(
self
):
"""Return the global rank of the last process in the group"""
return
self
.
ranks
[
-
1
]
@
property
def
is_first_rank
(
self
):
"""Return whether the caller is the first process in the group"""
return
self
.
rank
==
self
.
first_rank
@
property
def
is_last_rank
(
self
):
"""Return whether the caller is the last process in the group"""
return
self
.
rank
==
self
.
last_rank
@
property
def
next_rank
(
self
):
"""Return the global rank of the process that follows the caller"""
rank_in_group
=
self
.
rank_in_group
world_size
=
self
.
world_size
return
self
.
ranks
[(
rank_in_group
+
1
)
%
world_size
]
@
property
def
prev_rank
(
self
):
"""Return the global rank of the process that precedes the caller"""
rank_in_group
=
self
.
rank_in_group
world_size
=
self
.
world_size
return
self
.
ranks
[(
rank_in_group
-
1
)
%
world_size
]
@
contextmanager
def
graph_capture
(
self
,
graph_capture_context
:
Optional
[
GraphCaptureContext
]
=
None
):
if
graph_capture_context
is
None
:
stream
=
torch
.
cuda
.
Stream
()
graph_capture_context
=
GraphCaptureContext
(
stream
)
else
:
stream
=
graph_capture_context
.
stream
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream
=
torch
.
cuda
.
current_stream
()
if
curr_stream
!=
stream
:
stream
.
wait_stream
(
curr_stream
)
with
torch
.
cuda
.
stream
(
stream
):
yield
graph_capture_context
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we always make the all-reduce operation
out-of-place.
"""
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
if
self
.
use_custom_op_call
:
return
torch
.
ops
.
vllm
.
all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
else
:
return
self
.
_all_reduce_out_place
(
input_
)
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
device_communicator
.
all_reduce
(
input_
)
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
return
self
.
device_communicator
.
all_gather
(
input_
,
dim
)
def
gather
(
self
,
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dim
:
int
=
-
1
)
->
Optional
[
torch
.
Tensor
]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size
=
self
.
world_size
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
return
input_
return
self
.
device_communicator
.
gather
(
input_
,
dst
,
dim
)
def
all_to_all_4D
(
self
,
input_
:
torch
.
Tensor
,
scatter_dim
:
int
=
2
,
gather_dim
:
int
=
1
)
->
torch
.
Tensor
:
if
self
.
world_size
==
1
:
return
input_
return
self
.
device_communicator
.
all_to_all_4D
(
input_
,
scatter_dim
,
gather_dim
)
def
broadcast
(
self
,
input_
:
torch
.
Tensor
,
src
:
int
=
0
):
"""Broadcast the input tensor.
NOTE: `src` is the local rank of the source rank.
"""
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
# Broadcast.
torch
.
distributed
.
broadcast
(
input_
,
src
=
self
.
ranks
[
src
],
group
=
self
.
device_group
)
return
input_
def
broadcast_object
(
self
,
obj
:
Optional
[
Any
]
=
None
,
src
:
int
=
0
):
"""Broadcast the input object.
NOTE: `src` is the local rank of the source rank.
"""
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
obj
if
self
.
mq_broadcaster
is
not
None
:
assert
src
==
0
,
"Message queue broadcaster only supports src=0"
return
self
.
mq_broadcaster
.
broadcast_object
(
obj
)
if
self
.
rank_in_group
==
src
:
torch
.
distributed
.
broadcast_object_list
([
obj
],
src
=
self
.
ranks
[
src
],
group
=
self
.
cpu_group
)
return
obj
else
:
recv
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv
,
src
=
self
.
ranks
[
src
],
group
=
self
.
cpu_group
)
return
recv
[
0
]
def
broadcast_object_list
(
self
,
obj_list
:
List
[
Any
],
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
):
"""Broadcast the input object list.
NOTE: `src` is the local rank of the source rank.
"""
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
obj_list
# Broadcast.
torch
.
distributed
.
broadcast_object_list
(
obj_list
,
src
=
self
.
ranks
[
src
],
group
=
self
.
device_group
)
return
obj_list
def
send_object
(
self
,
obj
:
Any
,
dst
:
int
)
->
None
:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert
dst
<
self
.
world_size
,
f
"Invalid dst rank (
{
dst
}
)"
assert
dst
!=
self
.
rank_in_group
,
(
"Invalid destination rank. Destination rank is the same "
"as the current rank."
)
# Serialize object to tensor and get the size as well
object_tensor
=
torch
.
frombuffer
(
pickle
.
dumps
(
obj
),
dtype
=
torch
.
uint8
)
size_tensor
=
torch
.
tensor
([
object_tensor
.
numel
()],
dtype
=
torch
.
long
,
device
=
"cpu"
)
# Send object size
torch
.
distributed
.
send
(
size_tensor
,
dst
=
self
.
ranks
[
dst
],
group
=
self
.
cpu_group
)
# Send object
torch
.
distributed
.
send
(
object_tensor
,
dst
=
self
.
ranks
[
dst
],
group
=
self
.
cpu_group
)
return
None
def
recv_object
(
self
,
src
:
int
)
->
Any
:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
assert
src
!=
self
.
rank_in_group
,
(
"Invalid source rank. Source rank is the same as the current rank."
)
size_tensor
=
torch
.
empty
(
1
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
# Receive object size
rank_size
=
torch
.
distributed
.
recv
(
size_tensor
,
src
=
self
.
ranks
[
src
],
group
=
self
.
cpu_group
)
# Tensor to receive serialized objects into.
object_tensor
=
torch
.
empty
(
# type: ignore[call-overload]
size_tensor
.
item
(),
# type: ignore[arg-type]
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
rank_object
=
torch
.
distributed
.
recv
(
object_tensor
,
src
=
self
.
ranks
[
src
],
group
=
self
.
cpu_group
)
assert
rank_object
==
rank_size
,
(
"Received object sender rank does not match the size sender rank."
)
obj
=
pickle
.
loads
(
object_tensor
.
numpy
().
tobytes
())
return
obj
def
broadcast_tensor_dict
(
self
,
tensor_dict
:
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
,
metadata_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Optional
[
Dict
[
str
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if
(
not
torch
.
distributed
.
is_initialized
()
or
self
.
world_size
==
1
):
return
tensor_dict
group
=
self
.
device_group
metadata_group
=
self
.
cpu_group
assert
src
<
self
.
world_size
,
f
"Invalid src rank (
{
src
}
)"
rank_in_group
=
self
.
rank_in_group
if
rank_in_group
==
src
:
metadata_list
:
List
[
Tuple
[
Any
,
Any
]]
=
[]
assert
isinstance
(
tensor_dict
,
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
metadata_list
,
tensor_list
=
_split_tensor_dict
(
tensor_dict
)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self
.
broadcast_object
(
metadata_list
,
src
=
src
)
async_handles
=
[]
for
tensor
in
tensor_list
:
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
self
.
ranks
[
src
],
group
=
metadata_group
,
async_op
=
True
)
else
:
# use group for GPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
self
.
ranks
[
src
],
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
for
async_handle
in
async_handles
:
async_handle
.
wait
()
else
:
metadata_list
=
self
.
broadcast_object
(
None
,
src
=
src
)
tensor_dict
=
{}
async_handles
=
[]
for
key
,
value
in
metadata_list
:
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
torch
.
empty
(
value
.
size
,
dtype
=
value
.
dtype
,
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
tensor_dict
[
key
]
=
tensor
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
self
.
ranks
[
src
],
group
=
metadata_group
,
async_op
=
True
)
else
:
# use group for GPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
self
.
ranks
[
src
],
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
tensor_dict
[
key
]
=
tensor
else
:
tensor_dict
[
key
]
=
value
for
async_handle
in
async_handles
:
async_handle
.
wait
()
return
tensor_dict
    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None
    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: receive only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict
    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank in a non-blocking way.
        NOTE: `dst` is the local rank of the destination rank.
        """
        self.device_communicator.send(tensor, dst)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        return self.device_communicator.recv(size, dtype, src)
    def destroy(self) -> None:
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.device_communicator is not None:
            self.device_communicator.destroy()
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    assert _WORLD is not None, ("world group is not initialized")
    return _WORLD


def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
    )
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    assert _TP is not None, (
        "tensor model parallel group is not initialized")
    return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank

    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
_SP: Optional[GroupCoordinator] = None


def get_sp_group() -> GroupCoordinator:
    assert _SP is not None, (
        "sequence model parallel group is not initialized")
    return _SP
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    sequence_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        sequence_model_parallel_size: number of GPUs used for sequence model
            parallelism.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, (
        "tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    global _SP
    assert _SP is None, (
        "sequence model parallel group is already initialized")
    group_ranks = []
    # Since SP is incompatible with TP and PP, we can use a simpler group
    # creation logic.
    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        group_ranks.append(ranks)

    _SP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="sp")
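

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). A minimal sketch of the expected call order under torchrun, assuming
# the usual RANK/WORLD_SIZE/LOCAL_RANK environment variables are set; the
# helper name `_example_setup_parallel_state` is hypothetical.
def _example_setup_parallel_state() -> None:
    # Create the global WORLD group (env:// rendezvous by default).
    init_distributed_environment()
    # Carve the world into TP groups of one rank each and one SP group
    # spanning all ranks, e.g. 4 GPUs -> SP group [0, 1, 2, 3].
    initialize_model_parallel(
        tensor_model_parallel_size=1,
        sequence_model_parallel_size=torch.distributed.get_world_size())
    assert model_parallel_is_initialized()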
def get_sequence_model_parallel_world_size() -> int:
    """Return world size for the sequence model parallel group."""
    return get_sp_group().world_size


def get_sequence_model_parallel_rank() -> int:
    """Return my rank for the sequence model parallel group."""
    return get_sp_group().rank_in_group
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    sequence_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure that the tensor-parallel and sequence-parallel sizes are equal
    to the expected values if the model parallel groups are already
    initialized.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  sequence_model_parallel_size, backend)
        return

    assert (get_tensor_model_parallel_world_size() ==
            tensor_model_parallel_size), (
                "tensor parallel group already initialized, but of unexpected "
                f"size: {get_tensor_model_parallel_world_size()=} vs. "
                f"{tensor_model_parallel_size=}")
    if sequence_model_parallel_size > 1:
        sp_world_size = get_sp_group().world_size
        assert (sp_world_size == sequence_model_parallel_size), (
            "sequence parallel group already initialized, but of unexpected "
            f"size: {sp_world_size=} vs. "
            f"{sequence_model_parallel_size=}")
def model_parallel_is_initialized() -> bool:
    """Check if the tensor and sequence parallel groups are initialized."""
    return _TP is not None and _SP is not None


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run the draft
    model with a different tp degree from that of the target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group
def get_tensor_model_parallel_world_size() -> int:
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank() -> int:
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group
def destroy_model_parallel() -> None:
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _SP
    if _SP:
        _SP.destroy()
    _SP = None


def destroy_distributed_environment() -> None:
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()
    gc.collect()
    from fastvideo.v1.platforms import current_platform
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()
        try:
            torch._C._host_emptyCache()
        except AttributeError:
            logger.warning(
                "torch._C._host_emptyCache() only available in PyTorch >=2.5")
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns if each rank is in the same
    node as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)

                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
def initialize_tensor_parallel_group(
        tensor_model_parallel_size: int = 1,
        backend: Optional[str] = None,
        group_name_suffix: str = "") -> GroupCoordinator:
    """Initialize a tensor parallel group for a specific model.

    This function creates a tensor parallel group that can be used with the
    patch_tensor_parallel_group context manager. It allows different models
    to use different tensor parallelism configurations.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        backend: communication backend to use.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for tensor parallelism that can be used with
        the patch_tensor_parallel_group context manager.

    Example usage:
    ```python
    # Initialize tensor parallel group for model1
    tp_group_model1 = initialize_tensor_parallel_group(
        tensor_model_parallel_size=4,
        group_name_suffix="model1")

    # Use tensor parallelism for model1
    with patch_tensor_parallel_group(tp_group_model1):
        # Run model1 with tensor parallelism
        output1 = model1(input1)
    ```
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert world_size % tensor_model_parallel_size == 0, (
        f"World size ({world_size}) must be divisible by "
        f"tensor_model_parallel_size ({tensor_model_parallel_size})")

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    tp_group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        tp_group_ranks.append(ranks)

    # Create TP group coordinator with a unique name
    group_name = f"tp_{group_name_suffix}" if group_name_suffix else "tp"
    tp_group = init_model_parallel_group(tp_group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         use_message_queue_broadcaster=True,
                                         group_name=group_name)
    return tp_group
def initialize_sequence_parallel_group(
        sequence_model_parallel_size: int = 1,
        backend: Optional[str] = None,
        group_name_suffix: str = "") -> GroupCoordinator:
    """Initialize a sequence parallel group for a specific model.

    This function creates a sequence parallel group that can be used with the
    patch_sequence_parallel_group context manager. It allows different models
    to use different sequence parallelism configurations.

    Arguments:
        sequence_model_parallel_size: number of GPUs used for sequence model
            parallelism.
        backend: communication backend to use.
        group_name_suffix: optional suffix to make the group name unique.

    Returns:
        A GroupCoordinator for sequence parallelism that can be used with
        the patch_sequence_parallel_group context manager.

    Example usage:
    ```python
    # Initialize sequence parallel group for model2
    sp_group_model2 = initialize_sequence_parallel_group(
        sequence_model_parallel_size=2,
        group_name_suffix="model2")

    # Use sequence parallelism for model2
    with patch_sequence_parallel_group(sp_group_model2):
        # Run model2 with sequence parallelism
        output2 = model2(input2)
    ```
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    # Ensure the world size is compatible with the parallelism configuration
    assert world_size % sequence_model_parallel_size == 0, (
        f"World size ({world_size}) must be divisible by "
        f"sequence_model_parallel_size ({sequence_model_parallel_size})")

    # Build the sequence model-parallel groups.
    num_sequence_model_parallel_groups: int = (world_size //
                                               sequence_model_parallel_size)
    sp_group_ranks = []

    for i in range(num_sequence_model_parallel_groups):
        # Create groups of consecutive ranks
        ranks = list(
            range(i * sequence_model_parallel_size,
                  (i + 1) * sequence_model_parallel_size))
        sp_group_ranks.append(ranks)

    # Create SP group coordinator with a unique name
    group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp"
    sp_group = init_model_parallel_group(sp_group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         group_name=group_name)
    return sp_group
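

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). How two neighbouring ranks in the sequence-parallel group could
# exchange a tensor dictionary with the point-to-point helpers above; the
# helper name `_example_sp_ring_exchange` is hypothetical and an even number
# of ranks is assumed. By default, send_tensor_dict targets
# (rank_in_group + 1) % world_size and recv_tensor_dict pulls from
# (rank_in_group - 1) % world_size, so the calls are staggered to avoid both
# neighbours blocking in send at the same time.
def _example_sp_ring_exchange() -> None:
    sp = get_sp_group()
    payload = {"step": torch.tensor(0)}
    if sp.rank_in_group % 2 == 0:
        sp.send_tensor_dict(payload)
        received = sp.recv_tensor_dict()
    else:
        received = sp.recv_tensor_dict()
        sp.send_tensor_dict(payload)
    assert received is not None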
FastVideo-main/fastvideo/v1/distributed/utils.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/utils.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple

import torch
from torch.distributed import TCPStore

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


def ensure_divisibility(numerator, denominator) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns:
        A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tuple(tensor_list)
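

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). Splitting the last dimension into equal chunks: a (2, 8) tensor
# split four ways yields four (2, 2) views.
def _example_split_last_dim() -> None:
    x = torch.arange(16.0).reshape(2, 8)
    chunks = split_tensor_along_last_dim(x, num_partitions=4)
    assert len(chunks) == 4 and chunks[0].shape == (2, 2)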
@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass to hold a metadata store, and the rank, world_size of the
    group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.
    """
    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store

    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
    # src rank -> counter
    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
    broadcast_send_counter: int = 0
    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
        default_factory=dict)

    # A deque to store the data entries, with key and timestamp.
    entries: Deque[Tuple[str, float]] = dataclasses.field(
        default_factory=deque)

    def __post_init__(self):
        assert self.rank < self.world_size
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
        self.broadcast_recv_src_counter = {
            i: 0
            for i in range(self.world_size)
        }
    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank."""
        self.expire_data()
        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self) -> None:
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                break
    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank."""
        obj = pickle.loads(
            self.store.get(
                f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
        self.recv_src_counter[src] += 1
        return obj

    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it for limited times, e.g., for initialization.
        """
        if self.rank == src:
            self.expire_data()
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_send_counter}")
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_recv_src_counter[src]}")
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj
    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All gather an object from all ranks."""
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs

    def barrier(self):
        """A barrier to synchronize all ranks."""
        for i in range(self.world_size):
            if i == self.rank:
                self.broadcast_obj(None, src=self.rank)
            else:
                self.broadcast_obj(None, src=i)
    @staticmethod
    def create(
        host: str,
        port: int,
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does
        not pollute the global state.

        If we have process A and process B called
        `torch.distributed.init_process_group` to form a group, and then we
        want to form another group with process A, B, C, D, it is not
        possible in PyTorch, because process A and process B have already
        formed a group, and process C and process D cannot join that group.
        This function is a workaround for this issue.

        `torch.distributed.init_process_group` is a global call, while this
        function is a stateless call. It will return a
        `StatelessProcessGroup` object that can be used for exchanging
        metadata. With this function, process A and process B can call
        `StatelessProcessGroup.create` to form a group, and then process A,
        B, C, and D can call `StatelessProcessGroup.create` to form another
        group.
        """  # noqa
        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
            is_master=(rank == 0),
        )

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
            data_expiration_seconds=data_expiration_seconds)
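

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). Two processes exchanging metadata through a StatelessProcessGroup;
# the host/port values are placeholders, and every participant must call
# `create` with the same host/port and a unique rank.
def _example_stateless_group(rank: int) -> None:
    group = StatelessProcessGroup.create(host="127.0.0.1",
                                         port=29600,
                                         rank=rank,
                                         world_size=2)
    # Rank 0 publishes a config dict; the other rank reads it from the store.
    cfg = group.broadcast_obj({"sp_size": 2} if rank == 0 else None, src=0)
    group.barrier()
    assert cfg["sp_size"] == 2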
FastVideo-main/fastvideo/v1/entrypoints/__init__.py
0 → 100644
View file @
c07946d8
FastVideo-main/fastvideo/v1/entrypoints/cli/__init__.py
0 → 100644
View file @
c07946d8
FastVideo-main/fastvideo/v1/entrypoints/cli/cli_types.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/types.py
import argparse

from fastvideo.v1.utils import FlexibleArgumentParser


class CLISubcommand:
    """Base class for CLI subcommands"""

    name: str

    def cmd(self, args: argparse.Namespace) -> None:
        """Execute the command with the given arguments"""
        raise NotImplementedError

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the arguments for this command"""
        pass

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Initialize the subparser for this command"""
        raise NotImplementedError
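

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). A minimal hypothetical subcommand showing the hooks a concrete
# implementation overrides; `VersionSubcommand` does not exist in FastVideo,
# but the real `generate` subcommand follows the same pattern.
class VersionSubcommand(CLISubcommand):

    def __init__(self) -> None:
        self.name = "version"

    def cmd(self, args: argparse.Namespace) -> None:
        print("fastvideo 0.1.0")

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        # The real subcommands cast the returned parser to
        # FlexibleArgumentParser; kept simple here.
        return subparsers.add_parser("version", help="Print the version")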
FastVideo-main/fastvideo/v1/entrypoints/cli/generate.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
import argparse
from typing import List, cast

from fastvideo.v1.entrypoints.cli import utils
from fastvideo.v1.entrypoints.cli.cli_types import CLISubcommand
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.utils import FlexibleArgumentParser


class GenerateSubcommand(CLISubcommand):
    """The `generate` subcommand for the FastVideo CLI"""

    def __init__(self) -> None:
        self.name = "generate"
        super().__init__()

    def cmd(self, args: argparse.Namespace) -> None:
        excluded_args = [
            'subparser', 'config', 'num_gpus', 'master_port',
            'dispatch_function'
        ]
        # Create a filtered dictionary of arguments
        filtered_args = {
            k: v
            for k, v in vars(args).items()
            if k not in excluded_args and v is not None
        }

        main_args = []
        for key, value in filtered_args.items():
            # Convert underscores to dashes in argument names
            arg_name = f"--{key.replace('_', '-')}"
            # Handle boolean flags
            if isinstance(value, bool):
                if value:
                    main_args.append(arg_name)
            else:
                main_args.append(arg_name)
                main_args.append(str(value))

        utils.launch_distributed(args.num_gpus,
                                 main_args,
                                 master_port=args.master_port)

    def validate(self, args: argparse.Namespace) -> None:
        if args.num_gpus is not None and args.num_gpus <= 0:
            raise ValueError("Number of gpus must be positive")

        if args.master_port is not None and (args.master_port < 1024
                                             or args.master_port > 65535):
            raise ValueError("Master port must be between 1024 and 65535")

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        generate_parser = subparsers.add_parser(
            "generate",
            help="Run inference on a model",
            usage=
            "fastvideo generate --model-path MODEL_PATH_OR_ID --prompt PROMPT [OPTIONS]"
        )

        generate_parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config YAML file.")

        generate_parser.add_argument("--master-port",
                                     type=int,
                                     default=None,
                                     help="Port for the master process")

        generate_parser = FastVideoArgs.add_cli_args(generate_parser)

        return cast(FlexibleArgumentParser, generate_parser)


def cmd_init() -> List[CLISubcommand]:
    return [GenerateSubcommand()]
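

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). An example invocation matching the usage string above; the model
# path is a placeholder, and `--num-gpus` is assumed to be one of the flags
# registered by FastVideoArgs.add_cli_args:
#
#   fastvideo generate --model-path <MODEL_PATH_OR_ID> \
#       --prompt "a cat surfing a wave" --num-gpus 2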
FastVideo-main/fastvideo/v1/entrypoints/cli/main.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/main.py
from typing import List

from fastvideo.v1.entrypoints.cli.cli_types import CLISubcommand
from fastvideo.v1.entrypoints.cli.generate import cmd_init as generate_cmd_init
from fastvideo.v1.utils import FlexibleArgumentParser


def cmd_init() -> List[CLISubcommand]:
    """Initialize all commands from separate modules"""
    commands = []
    commands.extend(generate_cmd_init())
    return commands


def main() -> None:
    parser = FlexibleArgumentParser(description="FastVideo CLI")
    parser.add_argument('-v',
                        '--version',
                        action='version',
                        version='0.1.0')

    subparsers = parser.add_subparsers(required=False, dest="subparser")

    cmds = {}
    for cmd in cmd_init():
        cmd.subparser_init(subparsers).set_defaults(
            dispatch_function=cmd.cmd)
        cmds[cmd.name] = cmd

    args = parser.parse_args()
    if args.subparser in cmds:
        cmds[args.subparser].validate(args)

    if hasattr(args, "dispatch_function"):
        args.dispatch_function(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
FastVideo-main/fastvideo/v1/entrypoints/cli/utils.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
import os
import subprocess
import sys
from typing import List, Optional

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


def launch_distributed(num_gpus: int,
                       args: List[str],
                       master_port: Optional[int] = None) -> int:
    """
    Launch a distributed job with the given arguments.

    Args:
        num_gpus: Number of GPUs to use
        args: Arguments to pass to v1_fastvideo_inference.py (defaults to
            sys.argv[1:])
        master_port: Port for the master process (default: random)
    """
    current_env = os.environ.copy()

    python_executable = sys.executable

    project_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../.."))
    main_script = os.path.join(
        project_root, "fastvideo/v1/sample/v1_fastvideo_inference.py")

    cmd = [
        python_executable, "-m", "torch.distributed.run",
        f"--nproc_per_node={num_gpus}"
    ]

    if master_port is not None:
        cmd.append(f"--master_port={master_port}")

    cmd.append(main_script)
    cmd.extend(args)

    logger.info("Running inference with %d GPU(s)", num_gpus)
    logger.info("Launching command: %s", " ".join(cmd))

    current_env["PYTHONIOENCODING"] = "utf-8"

    process = subprocess.Popen(cmd,
                               env=current_env,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               universal_newlines=True,
                               bufsize=1,
                               encoding='utf-8',
                               errors='replace')

    if process.stdout:
        for line in iter(process.stdout.readline, ''):
            print(line.strip())

    return process.wait()
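

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). launch_distributed wraps `python -m torch.distributed.run`, so the
# argument list mirrors the CLI flags of the inference script; the values
# here are placeholders.
def _example_launch() -> int:
    return launch_distributed(num_gpus=2,
                              args=[
                                  "--model-path", "<MODEL_PATH_OR_ID>",
                                  "--prompt", "a red panda"
                              ],
                              master_port=29500)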
FastVideo-main/fastvideo/v1/entrypoints/video_generator.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
"""
VideoGenerator module for FastVideo.

This module provides a consolidated interface for generating videos using
diffusion models.
"""

import os
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Union

import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange

from fastvideo.v1.configs.pipelines import (PipelineConfig,
                                            get_pipeline_config_cls_for_name)
from fastvideo.v1.configs.sample import SamplingParam
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines import ForwardBatch
from fastvideo.v1.utils import align_to, shallow_asdict
from fastvideo.v1.worker.executor import Executor

logger = init_logger(__name__)


class VideoGenerator:
    """
    A unified class for generating videos using diffusion models.

    This class provides a simple interface for video generation with rich
    customization options, similar to popular frameworks like HF Diffusers.
    """
    def __init__(self, fastvideo_args: FastVideoArgs,
                 executor_class: type[Executor], log_stats: bool):
        """
        Initialize the video generator.

        Args:
            fastvideo_args: The inference arguments
            executor_class: The executor implementation used to run inference
            log_stats: Whether to log generation statistics
        """
        self.fastvideo_args = fastvideo_args
        self.executor = executor_class(fastvideo_args)
    @classmethod
    def from_pretrained(
            cls,
            model_path: str,
            device: Optional[str] = None,
            torch_dtype: Optional[torch.dtype] = None,
            pipeline_config: Optional[Union[str, PipelineConfig]] = None,
            **kwargs) -> "VideoGenerator":
        """
        Create a video generator from a pretrained model.

        Args:
            model_path: Path or identifier for the pretrained model
            device: Device to load the model on (e.g., "cuda", "cuda:0", "cpu")
            torch_dtype: Data type for model weights (e.g., torch.float16)
            pipeline_config: A PipelineConfig instance, or a path to a JSON
                config file
            **kwargs: Additional arguments to customize model loading

        Returns:
            The created video generator

        Priority level: Default pipeline config < User's pipeline config
        < User's kwargs
        """
        config = None
        # 1. If users provide a pipeline config, it will override the default
        #    pipeline config.
        if isinstance(pipeline_config, PipelineConfig):
            config = pipeline_config
        else:
            config_cls = get_pipeline_config_cls_for_name(model_path)
            if config_cls is not None:
                config = config_cls()
            if isinstance(pipeline_config, str):
                config.load_from_json(pipeline_config)

        # 2. If users also provide some kwargs, they will override the
        #    pipeline config. The user kwargs shouldn't contain model config
        #    parameters!
        if config is None:
            logger.warning(
                "No config found for model %s, using default config",
                model_path)
            config_args = kwargs
        else:
            config_args = shallow_asdict(config)
            config_args.update(kwargs)

        fastvideo_args = FastVideoArgs(
            model_path=model_path,
            # parenthesized so an explicitly passed `device` always wins
            device_str=device
            or ("cuda" if torch.cuda.is_available() else "cpu"),
            **config_args)
        fastvideo_args.check_fastvideo_args()

        return cls.from_fastvideo_args(fastvideo_args)
    @classmethod
    def from_fastvideo_args(
            cls, fastvideo_args: FastVideoArgs) -> "VideoGenerator":
        """
        Create a video generator with the specified arguments.

        Args:
            fastvideo_args: The inference arguments

        Returns:
            The created video generator
        """
        # Initialize distributed environment if needed
        # initialize_distributed_and_parallelism(fastvideo_args)

        executor_class = Executor.get_class(fastvideo_args)
        return cls(
            fastvideo_args=fastvideo_args,
            executor_class=executor_class,
            log_stats=False,  # TODO: implement
        )
    def generate_video(
        self,
        prompt: str,
        sampling_param: Optional[SamplingParam] = None,
        **kwargs,
    ) -> Union[Dict[str, Any], List[np.ndarray]]:
        """
        Generate a video based on the given prompt.

        Args:
            prompt: The prompt to use for generation
            sampling_param: Sampling parameters; if None, defaults are loaded
                for the pretrained model
            **kwargs: Overrides for individual sampling parameters, e.g.
                negative_prompt, output_path, save_video, return_frames,
                num_inference_steps, guidance_scale, num_frames, height,
                width, fps, seed, callback, and callback_steps

        Returns:
            Either the output dictionary or the list of frames depending on
            return_frames
        """
        # NOTE: this aliases (does not copy) the generator's inference args.
        fastvideo_args = self.fastvideo_args

        # Validate inputs
        if not isinstance(prompt, str):
            raise TypeError(
                f"`prompt` must be a string, but got {type(prompt)}")
        prompt = prompt.strip()

        if sampling_param is None:
            sampling_param = SamplingParam.from_pretrained(
                fastvideo_args.model_path)
        kwargs["prompt"] = prompt
        sampling_param.update(kwargs)

        # Process negative prompt
        if sampling_param.negative_prompt is not None:
            sampling_param.negative_prompt = (
                sampling_param.negative_prompt.strip())

        # Validate dimensions
        if (sampling_param.height <= 0 or sampling_param.width <= 0
                or sampling_param.num_frames <= 0):
            raise ValueError(
                f"Height, width, and num_frames must be positive integers, "
                f"got height={sampling_param.height}, "
                f"width={sampling_param.width}, "
                f"num_frames={sampling_param.num_frames}")

        temporal_ratio = (
            fastvideo_args.vae_config.arch_config.temporal_compression_ratio)
        if (sampling_param.num_frames - 1) % temporal_ratio != 0:
            raise ValueError(
                f"num_frames - 1 must be a multiple of {temporal_ratio}, "
                f"got {sampling_param.num_frames}")

        # Calculate sizes
        target_height = align_to(sampling_param.height, 16)
        target_width = align_to(sampling_param.width, 16)

        # Calculate latent sizes
        latents_size = [(sampling_param.num_frames - 1) // 4 + 1,
                        sampling_param.height // 8, sampling_param.width // 8]
        n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

        # Log parameters
        debug_str = f"""
                      height: {target_height}
                       width: {target_width}
                video_length: {sampling_param.num_frames}
                      prompt: {prompt}
                  neg_prompt: {sampling_param.negative_prompt}
                        seed: {sampling_param.seed}
                 infer_steps: {sampling_param.num_inference_steps}
       num_videos_per_prompt: {sampling_param.num_videos_per_prompt}
              guidance_scale: {sampling_param.guidance_scale}
                    n_tokens: {n_tokens}
                  flow_shift: {fastvideo_args.flow_shift}
     embedded_guidance_scale: {fastvideo_args.embedded_cfg_scale}"""
        logger.info(debug_str)

        # Prepare batch
        batch = ForwardBatch(
            **asdict(sampling_param),
            eta=0.0,
            n_tokens=n_tokens,
            extra={},
        )

        # Run inference
        start_time = time.time()
        output_batch = self.executor.execute_forward(batch, fastvideo_args)
        samples = output_batch

        gen_time = time.time() - start_time
        logger.info("Generated successfully in %.2f seconds", gen_time)

        # Process outputs
        videos = rearrange(samples, "b c t h w -> t b c h w")
        frames = []
        for x in videos:
            x = torchvision.utils.make_grid(x, nrow=6)
            x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
            frames.append((x * 255).numpy().astype(np.uint8))

        # Save video if requested
        if batch.save_video:
            save_path = batch.output_path
            if save_path:
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                video_path = os.path.join(save_path, f"{prompt[:100]}.mp4")
                imageio.mimsave(video_path, frames, fps=batch.fps)
                logger.info("Saved video to %s", video_path)
            else:
                logger.warning("No output path provided, video not saved")

        if batch.return_frames:
            return frames
        else:
            return {
                "samples": samples,
                "prompts": prompt,
                "size": (target_height, target_width, batch.num_frames),
                "generation_time": gen_time
            }
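

# --- Illustrative usage sketch (editor's addition, not part of the original
# file). A minimal end-to-end example; the model id is a placeholder, and the
# keyword overrides are assumed to be valid SamplingParam fields.
if __name__ == "__main__":
    generator = VideoGenerator.from_pretrained("<MODEL_PATH_OR_ID>")
    result = generator.generate_video(
        "A cat walks on the grass, realistic style.",
        num_inference_steps=50,
        save_video=True,
        output_path="outputs/")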