OpenDAS / ktransformers · Commits

Commit 877aec85 (unverified), authored Apr 09, 2025 by Yuhao Tsui, committed by GitHub on Apr 09, 2025

    Merge branch 'kvcache-ai:main' into main

Parents: 84164f58, 9037bf30
Changes: 251 — showing 20 changed files with 3320 additions and 26 deletions (+3320, -26)
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml                 +2    -2
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml                     +92   -0
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml                           +1    -1
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml                    +94   -0
ktransformers/server/args.py                                                          +30   -9
ktransformers/server/backend/args.py                                                  +1    -10
ktransformers/server/backend/context_manager.py                                       +12   -1
ktransformers/server/backend/interfaces/balance_serve.py                              +410  -0
ktransformers/server/backend/interfaces/ktransformers.py                              +2    -2
ktransformers/server/backend/interfaces/transformers.py                               +3    -1
ktransformers/server/balance_serve/inference/__init__.py                              +0    -0
ktransformers/server/balance_serve/inference/config.py                                +142  -0
ktransformers/server/balance_serve/inference/distributed/__init__.py                  +3    -0
ktransformers/server/balance_serve/inference/distributed/communication_op.py          +39   -0
ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py              +168  -0
ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py         +310  -0
ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py   +272  -0
ktransformers/server/balance_serve/inference/distributed/parallel_state.py            +1262 -0
ktransformers/server/balance_serve/inference/distributed/pynccl.py                    +201  -0
ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py            +276  -0
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml

@@ -66,7 +66,7 @@
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
+    class: ktransformers.operators.gate.KMoEGate
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
@@ -74,7 +74,7 @@
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGateDeepSeekV3  # mlp module with custom forward function
+    class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml  (new file, mode 100644)

- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoEV2  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.flashinfer_attn  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
  replace:
    class: ktransformers.operators.layernorm.RMSNorm
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
  replace:
    class: ktransformers.operators.mlp.kDeepseekV3MLP
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
\ No newline at end of file
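A rough sketch of how a rule entry like the ones above is interpreted, for readers unfamiliar with the format: both the `name` regex and, when present, the `class` entry must match before the `replace.class` and its kwargs are applied. This is an illustration only, not the actual injection logic in `ktransformers.optimize`; the simplified class check (suffix comparison) is an assumption for brevity.

import re
import yaml

def rule_matches(rule: dict, module_name: str, module_class: str) -> bool:
    match = rule.get("match", {})
    # name is a regular expression over the module's dotted path
    if "name" in match and re.match(match["name"], module_name) is None:
        return False
    # simplified stand-in for the "name and class simultaneously" check
    if "class" in match and not match["class"].endswith(module_class):
        return False
    return True

# assuming the YAML above is saved locally under this name
with open("DeepSeek-V3-Chat-serve.yaml") as f:
    rules = yaml.safe_load(f)

for rule in rules:
    if rule_matches(rule, "model.layers.10.mlp.experts", "ModuleList"):
        print(rule["replace"]["class"], rule["replace"].get("kwargs", {}))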
ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml

@@ -38,7 +38,7 @@
 - match:
     class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
-    class: ktransformers.operators.gate.KMoEGateDeepSeekV3
+    class: ktransformers.operators.gate.KMoEGate
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml  (new file, mode 100644)

- match:
    name: "^lm_head$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV3MoEV2  # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.MoEGate
  replace:
    class: ktransformers.operators.gate.KMoEGate
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE Kernel with expert paralleism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.flashinfer_attn  # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      absorb_for_prefill: False  # change this to True to enable long context(prefill may slower).
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
  replace:
    class: ktransformers.operators.layernorm.RMSNorm
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
  replace:
    class: ktransformers.operators.mlp.kDeepseekV3MLP
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV4
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
\ No newline at end of file
ktransformers/server/args.py

import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
from ktransformers.util.utils import get_free_ports


class ArgumentParser:
    def __init__(self, cfg):

@@ -16,20 +16,18 @@ class ArgumentParser:
        parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
        parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
        parser.add_argument("--model_dir", type=str)
        parser.add_argument("--model_path", type=str)
        parser.add_argument("--model_path", type=str, default=self.cfg.model_path)
        parser.add_argument("--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter")
        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
        parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
        parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
        parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
        parser.add_argument("--type", type=str, default=self.cfg.backend_type)
        parser.add_argument("--chunk_prefill_size", type=int, default=8192)
        parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type)
        parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size)
        # model configs
        # parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
        parser.add_argument("--paged", type=bool, default=self.cfg.paged)
        parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
        parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
        parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
        parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)

@@ -62,7 +60,6 @@ class ArgumentParser:
        parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
        parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
        parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
        parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
        parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
        parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
        parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)

@@ -73,6 +70,9 @@ class ArgumentParser:
        parser.add_argument("--batch_size", type=int, default=self.cfg.batch_size)
        parser.add_argument("--cache_lens", type=int, default=self.cfg.cache_lens)
+       # kvc2 config
+       parser.add_argument("--kvc2_config_dir", type=str, default=self.cfg.kvc2_config_dir)
        # log configs
        # log level: debug, info, warn, error, crit
        parser.add_argument("--log_dir", type=str, default=self.cfg.log_dir)

@@ -103,6 +103,18 @@ class ArgumentParser:
        # local chat
        parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
+       # async server
+       parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy)
+       # parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port)
+       # parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port)
+       # parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port)
+       parser.add_argument("--page_size", type=str, default=self.cfg.page_size)
+       parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only)
+       parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage)
+       parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB)

        args = parser.parse_args()
        if (args.model_dir is not None or args.model_path is not None):
            if (args.model_path is not None):

@@ -123,6 +135,15 @@ class ArgumentParser:
        self.cfg.mount_web = args.web
        self.cfg.server_ip = args.host
        self.cfg.server_port = args.port
        self.cfg.backend_type = args.type
        self.cfg.user_force_think = args.force_think

+       args.gpu_memory_size = args.cache_lens * 2 * 576 * 61
+       self.cfg.gpu_memory_size = args.gpu_memory_size
+       free_ports = get_free_ports(3, [args.port])
+       args.sched_port = free_ports[0]
+       args.sched_metrics_port = free_ports[1]
+       args.kvc2_metrics_port = free_ports[2]
+       self.cfg.sched_port = free_ports[0]
+       self.cfg.sched_metrics_port = free_ports[1]
+       self.cfg.kvc2_metrics_port = free_ports[2]
        return args
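A quick back-of-the-envelope check of the new gpu_memory_size formula above. The constants are not explained in the diff; a plausible reading (assumption, not stated in the source) is that 576 is DeepSeek-V3's per-token MLA cache width (512 kv_lora_rank + 64 rope dims), 61 its layer count, and 2 the bytes per bf16 element, so the product is KV-cache bytes for cache_lens tokens.

# illustrative arithmetic only; cache_lens value is hypothetical
cache_lens = 8192
gpu_memory_size = cache_lens * 2 * 576 * 61
print(gpu_memory_size / 2**20)  # 549.0 MiB of KV cache for 8192 tokens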
ktransformers/server/backend/args.py

@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
    class Config:
        protected_namespaces = ()

    paged: bool = Field(None, description="Whether to use paged attention kv cache")
    total_context: int = Field(
        None,
        description=(
            "Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
            " total to distribute dynamically over however many jobs are active at once"
        ),
    )
    max_batch_size: int = Field(
        None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
    )
-   chunk_prefill_size: int = Field(
+   chunk_size: int = Field(
        None,
        description=(
            "Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"

@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
    repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
    frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
    presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
    max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
    response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
    no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
    cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
ktransformers/server/backend/context_manager.py

@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThreadContext
 from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
 from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
 from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
 from ktransformers.server.backend.interfaces.transformers import TransformersInterface
 from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface


 class ThreadContextManager:
     lock: Lock
     threads_context: Dict[ObjectID, ThreadContext]

@@ -36,7 +38,16 @@ class ThreadContextManager:
             elif isinstance(self.interface, TransformersInterface):
                 new_context = TransformersThreadContext(run, self.interface)
             else:
-                raise NotImplementedError
+                from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
+                from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
+                if isinstance(self.interface, BalanceServeInterface):
+                    new_context = BalanceServeThreadContext(run, self.interface)
+                else:
+                    raise NotImplementedError
+            # elif isinstance(self.interface, BalanceServeInterface):
+            #     new_context = BalanceServeThreadContext(run, self.interface)
+            # else:
+            #     raise NotImplementedError
             self.threads_context[run.thread_id] = new_context
             # self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
             re = self.threads_context[run.thread_id]
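The balance_serve classes are imported inside the method rather than at module top, presumably to keep the balance_serve stack (scheduler RPC, CUDA extensions) out of the import path unless that backend is actually selected. A generic, illustrative form of that lazy-import dispatch pattern (not the project's exact code):

def build_context(run, interface):
    # deferred import: only paid for when this branch is reached
    from ktransformers.server.backend.interfaces.balance_serve import (
        BalanceServeInterface,
        BalanceServeThreadContext,
    )
    if isinstance(interface, BalanceServeInterface):
        return BalanceServeThreadContext(run, interface)
    raise NotImplementedError(type(interface).__name__)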
ktransformers/server/backend/interfaces/balance_serve.py  (new file, mode 100644)

from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
    AutoTokenizer,
    AutoConfig,
    GenerationConfig,
    StaticCache,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
    ConfigArgs,
    default_args,
    TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os

ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}


async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
    streamer = TextStreamer(tokenizer)
    while True:
        token = await queue.get()
        #print(f"Got token: {token}")
        if token is None:
            # str = f'{token}\n\n'
            # str = model.tokenizer.decode(token)
            s = streamer.end()
            if s is not None:
                yield s
            break
        # str = model.tokenizer.decode(token)
        yield streamer.put(token)


def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
    #print(len(query_updates), generated_tokens.size(0), generated_tokens)
    for i in range(generated_tokens.size(0)):
        print(generated_tokens[i].item())
        query_updates[i].generated_token = generated_tokens[i].item()
        if not query_manager.query_map[query_updates[i].id].is_prefill:
            pos = query_updates[i].active_position
            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]


def report_last_time_performance(profiler: Profiler):
    try:
        tokenize_time = profiler.get_timer_sec('tokenize')
        prefill_time = profiler.get_timer_sec('prefill')
        decode_time = profiler.get_timer_sec('decode')
        prefill_count = profiler.get_counter('prefill')
        decode_count = profiler.get_counter('decode')

        logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except:
        logger.info(f'Performance statistics not recorded')


class Engine:
    sched_client: SchedulerClient
    updates: list[sched_ext.QueryUpdate]
    batch: sched_ext.BatchQueryTodo
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache

    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None):
        self.args = args
        # 子进程和父进程无法共享 config 变量
        for key, value in vars(args).items():
            if value is not None and hasattr(Config(), key):
                setattr(Config(), key, value)
        self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []

        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        self.cache = KDeepSeekV3Cache(config, self.args.page_size)

        self.gen_queue = generated_token_queue

        print(f"Getting inference context from sched_client.")
        inference_context = self.sched_client.get_inference_context_raw()
        print(f"Got inference context, sending it to subscribers.")
        inference_context = self.sched_client.rebuild_inferece_context(inference_context)
        self.cache.load(inference_context)
        print(f"kv_cache loaded successfully.")
        self.block_num = inference_context.k_cache[0].size(1)

        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
        # print(self.block_num)

        context = zmq.Context()
        self.pub_socket = context.socket(zmq.PUB)
        self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
        # time.sleep(1) # make sure all subscribers are ready

        try:
            generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except:
            generation_config = GenerationConfig(
                max_length=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True
            )

        if args.optimize_config_path is None:
            optimize_config_path = default_optimize_rules[config.architectures[0]]
        else:
            optimize_config_path = args.optimize_config_path

        gguf_path = args.gguf_path
        if gguf_path is None:
            gguf_path = input(
                "please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
                " belong to current model):"
            )
        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)

        self.model.generation_config = generation_config
        if self.model.generation_config.pad_token_id is None:
            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
        self.model.eval()

        #@TODO add config
        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size=args.page_size)
        self.sampler = Sampler()
        self.query_manager = QueryManager(device=self.device, page_size=args.page_size)

    def sampling(self, forward_output: ForwardBatchOutput):
        generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
        for i in range(forward_output.num_batchs):
            logit = forward_output.logits[i]
            if hasattr(forward_output, "temperatures"):
                temperatures = forward_output.temperatures[i]
            else:
                temperatures = None
            if hasattr(forward_output, "top_ps"):
                top_ps = forward_output.top_ps[i]
            else:
                top_ps = None
            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_tokens, probs = self.sampler(logit, sample_options)
        return generated_tokens, probs

    def loop(self):
        next_batch = None
        while True:
            self.batch = next_batch
            if self.batch is not None:
                self.model_runner.run(self.batch, self.query_manager)

            if len(self.updates) > 0:
                for q in self.updates:
                    if q.is_prefill == True:
                        continue
                    # print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
                    try:
                        self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
                    except queue.Full:
                        pass
                        #print("Queue is full after timeout; unable to put more items.")

            next_batch = self.sched_client.update_last_batch(self.updates)
            if next_batch.query_ids == []:
                next_batch = None
            self.pub_socket.send_pyobj(next_batch)

            if next_batch is not None:
                self.query_manager.add_query(next_batch)

            if self.batch is not None:
                self.model_runner.sync()
                print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
                # if self.rank == 0:
                generated_tokens, probs = self.sampling(self.model_runner.output)

                self.updates = self.query_manager.update(self.batch)
                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
            else:
                self.updates = []


class BalanceServeThreadContext(ThreadContext):
    def get_local_messages(self):
        local_messages = []
        for m in self.messages:
            local_messages.append({"role": m.role.value, "content": m.get_text_content()})
        return local_messages


def run_engine(args, token_queue, broadcast_endpoint, event):
    engine = Engine(args, token_queue, broadcast_endpoint)
    if args.use_cuda_graph:
        engine.model_runner.warmup()
    event.set()
    engine.loop()


class BalanceServeInterface(BackendInterfaceBase):
    use_static_cache: bool = True

    model: Any
    tokenizer: AutoTokenizer

    cache: StaticCache
    generated_ids: torch.Tensor
    seq_length: int

    streamer: TextStreamer

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        self.queue_map: dict[int, asyncio.Queue] = {}
        self.thread_map: dict[int, int] = {}
        processes = []
        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name  # @TODO add to config
        ctx = mp.get_context("spawn")
        self.token_queue = ctx.Queue(maxsize=1000)

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)

        self.sched_client = SchedulerClient(args.sched_port)
        self.streamer = TextStreamer(self.tokenizer)

        start_event = ctx.Event()

        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
        p.start()
        processes.append(p)
        start_event.wait()

    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.queue_proxy())

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        asyncio.create_task(self.queue_proxy())
        yield

    async def queue_proxy(self):
        print("Queue Proxy Started")
        while True:
            try:
                query_id, token = self.token_queue.get_nowait()
                try:
                    # query id might not be allocated yet
                    self.queue_map[query_id].put_nowait(token)
                    #print(f"Proxy Put token: {token} to queue for query id: {query_id}")
                except asyncio.QueueFull:
                    #print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
                    await self.queue_map[query_id].put(token)
            except queue.Empty:
                # print("no new token")
                # await asyncio.sleep(1)
                await asyncio.sleep(0)

    def tokenize_prompt(self, prompt: str):
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
        return input_ids

    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
        for m in messages:
            if m["role"] == "system":
                logger.warning(f'change {m["role"]} to user')
                m["role"] = "user"
        new_messages = [messages[0]]
        for m in messages[1:]:
            if m["role"] == "user" and new_messages[-1]["role"] == "user":
                logger.warning("merge two adjacent user messages")
                new_messages[-1]["content"] += '\n' + m["content"]
            else:
                new_messages.append(m)
        input_str: str = self.tokenizer.apply_chat_template(new_messages, tokenize=False, add_generation_prompt=True)
        # drop <think> token in chat template
        if input_str.endswith('<think>\n'):
            input_str = input_str[:-len('<think>\n')]
        input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
        logger.debug(f"get input ids of shape {input_ids.shape}")
        return input_ids

    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
        profiler = Profiler()
        profiler.create_and_start_timer("tokenize")

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")

        if Config().user_force_think:
            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n", add_special_tokens=False)], device=input_ids.device)
            input_ids = torch.cat([input_ids, token_thinks], dim=1)

        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")

        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
        query_length = input_ids[0].shape[0]
        query_add.query_length = query_length
        profiler.set_counter("prefill", query_length)
        #@TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False), self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
        if temperature == 0:
            temperature = 0.0001
        query_add.sample_options.temperature = temperature
        if top_p == 0:
            top_p = 0.0001
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length + self.args.max_new_tokens)
        query_id = self.sched_client.add_query(query_add)
        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
        self.queue_map[query_id] = queue
        self.thread_map[thread_id] = query_id

        is_first_token = True
        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
            if is_first_token:
                is_first_token = False
                profiler.pause_timer("prefill")
                profiler.create_and_start_timer("decode")
                profiler.set_counter("decode", 0)
                if Config().user_force_think:
                    think = '<think>\n'
                    print(think, end="", flush=True)
                    yield think, None
            else:
                profiler.inc("decode")
            yield token, None
        profiler.pause_timer("decode")
        report_last_time_performance(profiler)
        yield self.streamer.end(), None
        if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
            yield "", "length"
        else:
            yield "", "stop"

        yield RawUsage(
            tokenize_time=profiler.get_timer_sec('tokenize'),
            prefill_time=profiler.get_timer_sec('prefill'),
            decode_time=profiler.get_timer_sec('decode'),
            prefill_count=profiler.get_counter('prefill'),
            decode_count=profiler.get_counter('decode'),
        )
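BalanceServeInterface above spawns the Engine in a separate process, blocks on an Event until warm-up finishes, and then streams (query_id, token) pairs back through a torch.multiprocessing Queue. A minimal, self-contained sketch of that handshake pattern follows; the worker body is a stand-in, not the real Engine.

import torch.multiprocessing as mp

def worker(token_queue, ready_event):
    ready_event.set()                 # signal "model loaded / warmed up"
    for tok in [101, 102, None]:      # None marks end-of-stream, as in chat_stream()
        token_queue.put(tok)

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q, ready = ctx.Queue(maxsize=1000), ctx.Event()
    p = ctx.Process(target=worker, args=(q, ready))
    p.start()
    ready.wait()                      # mirrors start_event.wait() above
    while (tok := q.get()) is not None:
        print("got token", tok)
    p.join()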
ktransformers/server/backend/interfaces/ktransformers.py

@@ -211,11 +211,11 @@ class KTransformersInterface(TransformersInterface):
         chunk_start = 0
         while chunk_start < input_ids_length:
-            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
             if self.cache != None:
                 self.cache.cur_idx = cache_position[chunk_start:chunk_end]
             logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
-            chunk_start += self.args.chunk_prefill_size
+            chunk_start += self.args.chunk_size
             if flashinfer_enabled:
                 MLAWrapperSingleton.reset_buffer()
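The renamed chunk_size drives the chunked-prefill loop above, which walks the prompt in fixed-size windows. A quick illustration of the resulting chunk boundaries, using hypothetical lengths rather than project defaults:

# chunk boundaries produced by the loop above for a 20_000-token prompt
chunk_size, input_ids_length = 8192, 20_000
chunk_start = 0
while chunk_start < input_ids_length:
    chunk_end = min(chunk_start + chunk_size, input_ids_length)
    print(chunk_start, chunk_end)   # (0, 8192), (8192, 16384), (16384, 20000)
    chunk_start += chunk_size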
ktransformers/server/backend/interfaces/transformers.py

@@ -208,6 +208,8 @@ class TransformersInterface(BackendInterfaceBase):
             temperature = self.model.generation_config.temperature
         if top_p is None:
             top_p = self.model.generation_config.top_p
+        if top_p == 0:
+            top_p = 0.0001
         generation_config, model_kwargs = self.model._prepare_generation_config(
             None, max_length=self.args.max_new_tokens,
             do_sample=True,

@@ -341,7 +343,7 @@ class TransformersInterface(BackendInterfaceBase):
         for i in range(1, self.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
-                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1,
+                    MLAWrapperSingleton.plan_all(None, None, None, self.active_cache_position.to(torch.int32) + 1, None,
                                                  num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                                  head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
                                                  sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
ktransformers/server/balance_serve/inference/__init__.py  (new file, mode 100644, empty)
ktransformers/server/balance_serve/inference/config.py  (new file, mode 100644)

'''
Date: 2024-11-07 07:30:16
LastEditors: djw
LastEditTime: 2024-11-15 14:23:26
'''
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import yaml
import json
from typing import Optional


class ModelConfig:
    vocab_size: int = 32000
    n_layer: int = 1
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = 18944
    n_local_heads: int = 8
    head_dim: int = 128
    rope_base: float = 1000000.0
    norm_eps: float = 1e-06
    rope_scaling: Optional[dict] = None
    rms_norm_eps: float = 1e-6
    hidden_act: str = "silu"

    model_path: str
    gguf_path: str
    optimize_rule_path: str
    speculative_rule_path: str

    # quantize config
    quant_algorithm: Optional[str] = None
    quant_group_size: Optional[int] = None
    quant_num_bits: Optional[int] = None

    json_key_map = {
        "vocab_size": "vocab_size",
        "n_layer": "num_hidden_layers",
        "n_head": "num_attention_heads",
        "dim": "hidden_size",
        "intermediate_size": "intermediate_size",
        "n_local_heads": "num_key_value_heads",
        "rope_base": "rope_theta",
        "norm_eps": "norm_eps",
        "rms_norm_eps": "rms_norm_eps",
        "hidden_act": "hidden_act",
    }

    def __init__(self, config):
        self.model_path = config["model"]["model_path"]
        self.gguf_path = config["model"]["gguf_path"]
        self.optimize_rule_path = config["model"]["optimize_rule_path"]
        if "speculative_rule_path" in config["model"]:
            self.speculative_rule_path = config["model"]["speculative_rule_path"]
            self.speculative_gguf_path = config["model"]["speculative_gguf_path"]
            self.speculative_model_path = config["model"]["speculative_model_path"]
        self.quant_algorithm = config["model"]["quant"]["algorithm"]
        self.quant_group_size = config["model"]["quant"]["group_size"]
        self.quant_num_bits = config["model"]["quant"]["num_bits"]
        self.load_config()
        self.n_layer = config["model"]["n_layers"]

    def load_config(self):
        config_file = f"{self.model_path}/config.json"
        try:
            with open(config_file, "r") as f:
                config_data = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Configuration file not found at {config_file}")

        for attr, json_key in self.json_key_map.items():
            if json_key in config_data:
                setattr(self, attr, config_data[json_key])
            else:
                setattr(self, attr, getattr(self, attr))


class ParallelConfig:
    def __init__(
        self,
        config,
    ) -> None:
        self.pipeline_parallel_size = config["parallel"]["pp"]
        self.tensor_parallel_size = config["parallel"]["tp"]
        self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"]
        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size


class AttnConfig:
    page_size: int = 256
    block_num: int = 32
    max_batch_token: int = 256
    max_batch_size: int = 32

    def __init__(self, config):
        self.page_size = config["attn"]["page_size"]
        self.block_num = config["attn"]["block_num"]
        self.max_batch_token = config["attn"]["max_batch_token"]
        self.max_batch_size = config["attn"]["max_batch_size"]


class SamplerConfig():
    # Batched sampling params
    temperatures: float
    is_all_greedy: bool

    def __init__(self, config):
        self.temperatures = config["sample"]["temperature"]
        self.is_all_greedy = True


def load_yaml_config(file_path):
    with open(file_path, "r") as f:
        return yaml.safe_load(f)


class LLMConfig:
    model_config: ModelConfig
    parallel_config: ParallelConfig
    attn_config: AttnConfig
    sample_config: SamplerConfig
    config_file: str

    def __init__(self, config_file):
        self.config_file = config_file
        config = load_yaml_config(config_file)
        self.model_config = ModelConfig(config)
        self.parallel_config = ParallelConfig(config)
        self.attn_config = AttnConfig(config)
        self.sample_config = SamplerConfig(config)
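For orientation, this is the shape of the nested config these classes appear to expect, inferred only from the keys they read in their constructors; the values below are placeholders, not shipped defaults, and ModelConfig additionally reads <model_path>/config.json via load_config().

example_config = {
    "model": {
        "model_path": "/models/DeepSeek-V3",                  # placeholder path
        "gguf_path": "/models/DeepSeek-V3-GGUF",               # placeholder path
        "optimize_rule_path": "optimize_rules/DeepSeek-V3-Chat-serve.yaml",
        "n_layers": 61,
        "quant": {"algorithm": None, "group_size": 128, "num_bits": 4},
    },
    "parallel": {"pp": 1, "tp": 1, "disable_custom_all_reduce": False},
    "attn": {"page_size": 256, "block_num": 32,
             "max_batch_token": 256, "max_batch_size": 32},
    "sample": {"temperature": 0.6},
}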
ktransformers/server/balance_serve/inference/distributed/__init__.py  (new file, mode 100644)

from .communication_op import *
from .parallel_state import *
from .utils import *
ktransformers/server/balance_serve/inference/distributed/communication_op.py  (new file, mode 100644)

"""
Date: 2024-12-11 06:02:42
LastEditors: djw
LastEditTime: 2024-12-12 09:52:06
"""
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed

from .parallel_state import get_tp_group


def tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group."""
    return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)


def tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    return get_tp_group().all_gather(input_, dim)


def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]:
    """Gather the input tensor across model parallel group."""
    return get_tp_group().gather(input_, dst, dim)


def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0):
    if not torch.distributed.is_initialized():
        return tensor_dict
    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py  (new file, mode 100644)

"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""

import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch  # noqa

# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html

cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


def find_loaded_library(lib_name) -> Optional[str]:
    """
    According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process, which includes the
    shared libraries loaded by the process. We can use this file to find the path of the
    a loaded library.
    """  # noqa
    found = False
    with open("/proc/self/maps") as f:
        for line in f:
            if lib_name in line:
                found = True
                break
    if not found:
        # the library is not loaded in the current process
        return None
    # if lib_name is libcudart, we need to match a line with:
    # address /path/to/libcudart-hash.so.11.0
    start = line.index("/")
    path = line[start:].strip()
    filename = path.split("/")[-1]
    assert filename.rpartition(".so")[0].startswith(lib_name), \
        f"Unexpected filename: {filename} for library {lib_name}"
    return path


class CudaRTLibrary:
    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),
        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function("cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind]),
        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function("cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        if so_file is None:
            so_file = find_loaded_library("libcudart")
            assert so_file is not None, \
                "libcudart is not loaded in the current process"
        if so_file not in CudaRTLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int) -> None:
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr))
        return handle

    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
        return devPtr
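A hedged usage sketch of the wrapper above, mirroring what producer() in custom_all_reduce_utils.py does: allocate device memory, fill it, and export an IPC handle that another process could open. It requires a CUDA machine with libcudart already loaded (hence the torch import), and assumes the module is importable at the path shown in this commit.

import torch  # ensures libcudart is loaded so find_loaded_library() can locate it
from ktransformers.server.balance_serve.inference.distributed.cuda_wrapper import CudaRTLibrary

lib = CudaRTLibrary()
lib.cudaSetDevice(0)
ptr = lib.cudaMalloc(1024)          # 1 KiB device buffer
lib.cudaMemset(ptr, 1, 1024)
lib.cudaDeviceSynchronize()
handle = lib.cudaIpcGetMemHandle(ptr)   # opaque handle, can be sent to another process
lib.cudaFree(ptr)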
ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py
0 → 100644
View file @
877aec85
import
ctypes
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
,
Union
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
import
server.envs
as
envs
from
server.inference.distributed.cuda_wrapper
import
CudaRTLibrary
from
server.inference.distributed.custom_all_reduce_utils
import
gpu_p2p_access_check
from
server.inference.distributed.parallel_state
import
in_the_same_node_as
from
server.inference.platforms
import
current_platform
from
server.utils
import
cuda_device_count_stateless
import
vLLMCustomAllreduce
try
:
vLLMCustomAllreduce
.
meta_size
()
custom_ar
=
True
except
Exception
:
# For AMD GPUs and CPUs
custom_ar
=
False
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
for
i
in
range
(
world_size
):
if
i
==
rank
:
continue
if
envs
.
VLLM_SKIP_P2P_CHECK
:
print
(
"Skipping P2P check and trusting the driver's P2P report."
)
return
torch
.
cuda
.
can_device_access_peer
(
rank
,
i
)
if
not
gpu_p2p_access_check
(
rank
,
i
):
return
False
return
True
def
is_weak_contiguous
(
inp
:
torch
.
Tensor
):
return
inp
.
is_contiguous
()
or
(
inp
.
storage
().
nbytes
()
-
inp
.
storage_offset
()
*
inp
.
element_size
()
==
inp
.
numel
()
*
inp
.
element_size
()
)
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
# max_size: max supported allreduce size
def
__init__
(
self
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
max_size
=
8192
*
1024
,
)
->
None
:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self
.
_IS_CAPTURING
=
False
self
.
disabled
=
True
if
not
custom_ar
:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
self
.
group
=
group
assert
(
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
),
"CustomAllreduce should be attached to a non-NCCL group."
if
not
all
(
in_the_same_node_as
(
group
,
source_rank
=
0
)):
# No need to initialize custom allreduce for multi-node case.
print
(
"Custom allreduce is disabled because this process group"
" spans across nodes."
)
return
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
return
if
world_size
not
in
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
:
print
(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
),
)
return
if
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
cuda_device_count_stateless
()))
physical_device_id
=
device_ids
[
device
.
index
]
tensor
=
torch
.
tensor
([
physical_device_id
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
gather_list
=
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
for
_
in
range
(
world_size
)
]
dist
.
all_gather
(
gather_list
,
tensor
,
group
=
self
.
group
)
physical_device_ids
=
[
t
.
item
()
for
t
in
gather_list
]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
assert
current_platform
.
is_cuda
()
from
server.inference.platforms.cuda
import
CudaPlatform
cuda_platform
:
CudaPlatform
=
current_platform
full_nvlink
=
cuda_platform
.
is_full_nvlink
(
physical_device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
print
(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
print
(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self
.
disabled
=
False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self
.
meta_ptrs
=
self
.
create_shared_buffer
(
vLLMCustomAllreduce
.
meta_size
()
+
max_size
,
group
=
group
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self
.
buffer_ptrs
=
self
.
create_shared_buffer
(
max_size
,
group
=
group
)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen only
        # needs less than 10000 of registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        self.full_nvlink = full_nvlink
        self._ptr = vLLMCustomAllreduce.init_custom_ar(self.meta_ptrs, self.rank_data, rank, self.full_nvlink)
        vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs)

    @staticmethod
    def create_shared_buffer(size_in_bytes: int, group: Optional[ProcessGroup] = None) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
        return pointers

    @staticmethod
    def free_shared_buffer(pointers: List[int], group: Optional[ProcessGroup] = None) -> None:
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def register_graph_buffers(self):
        handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr)
        print("Registering %d cuda graph addresses", len(offset))
        # We cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
        all_data[self.rank] = [handle, offset]
        ranks = sorted(dist.get_process_group_ranks(group=self.group))
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
        # Unpack list of tuples to tuple of lists.
        handles = [d[0] for d in all_data]  # type: ignore
        offsets = [d[1] for d in all_data]  # type: ignore
        vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        if self.world_size == 2 or self.full_nvlink:
            return inp_size < self.max_size
        return False

    def all_reduce(
        self,
        inp: torch.Tensor,
        *,
        out: torch.Tensor = None,
        bsz_tensor: torch.Tensor = None,
        registered: bool = False,
        is_compute_bound=False,
        overlap=False,
    ):
        """Performs an out-of-place all reduce.

        If registered is True, this assumes inp's pointer is already
        IPC-registered. Otherwise, inp is first copied into a pre-registered
        buffer.
        """
        if is_compute_bound:
            sms = 2 if overlap else 36
        else:
            sms = 20 if overlap else 36
        # print("all reduce sms", sms)
        if out is None:
            out = torch.empty_like(inp)
        if registered:
            vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms)
        else:
            vLLMCustomAllreduce.all_reduce(
                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms
            )
        return out

    def custom_all_reduce(
        self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False
    ) -> Optional[torch.Tensor]:
        """The main allreduce API that provides support for cuda graph."""
        # When custom allreduce is disabled, this will be None.
        if self.disabled or not self.should_custom_ar(input):
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                return self.all_reduce(
                    input,
                    bsz_tensor=bsz_tensor,
                    registered=True,
                    is_compute_bound=is_compute_bound,
                    overlap=overlap,
                )
            else:
                # If warm up, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
            # Note: outside of cuda graph context, custom allreduce incurs a
            # cost of cudaMemcpy, which should be small (<=1% of overall
            # latency) compared to the performance gain of using custom kernels
            return self.all_reduce(
                input,
                bsz_tensor=bsz_tensor,
                registered=False,
                is_compute_bound=is_compute_bound,
                overlap=overlap,
            )

    def close(self):
        if not self.disabled and self._ptr:
            vLLMCustomAllreduce.dispose(self._ptr)
            self._ptr = 0
            self.free_shared_buffer(self.meta_ptrs)
            self.free_shared_buffer(self.buffer_ptrs)

    def __del__(self):
        self.close()
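For orientation, here is a minimal usage sketch of the CustomAllreduce wrapper above, assuming the process group and communicator were already constructed as in the class `__init__`. The helper name `allreduce_with_fallback` and the plain NCCL fallback are illustrative and not part of this commit; the `None` return on rejection follows `custom_all_reduce`/`should_custom_ar` as written above.

# Illustrative sketch only: route an allreduce through CustomAllreduce when it
# accepts the tensor (size a multiple of 16 bytes, below max_size), otherwise
# fall back to torch.distributed's NCCL allreduce.
import torch
import torch.distributed as dist

def allreduce_with_fallback(ca, t: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
    out = ca.custom_all_reduce(t, bsz_tensor) if ca is not None else None
    if out is None:
        dist.all_reduce(t)  # in-place NCCL fallback
        return t
    return out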
ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py
0 → 100644
View file @
877aec85
import ctypes
import json
import os
import pickle
import subprocess
import sys
import tempfile
from itertools import product
from typing import Dict, List, Optional, Sequence

import torch.distributed as dist
import torch.multiprocessing as mp

import server.envs as envs
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
from server.utils import cuda_device_count_stateless, update_environment_variables


def producer(
    batch_src: Sequence[int],
    producer_queue,
    consumer_queue,
    result_queue,
    cuda_visible_devices: Optional[str] = None,
):
    if cuda_visible_devices is not None:
        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
    lib = CudaRTLibrary()
    for i in batch_src:
        lib.cudaSetDevice(i)
        pointer = lib.cudaMalloc(1024)
        lib.cudaMemset(pointer, 1, 1024)
        lib.cudaDeviceSynchronize()
        handle = lib.cudaIpcGetMemHandle(pointer)
        producer_queue.put(handle)
        open_success = consumer_queue.get()
        if open_success:
            # use two queues to simulate barrier
            producer_queue.put(0)
            consumer_queue.get()
            # check if the memory is modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        lib.cudaDeviceReset()


def consumer(
    batch_tgt: Sequence[int],
    producer_queue,
    consumer_queue,
    result_queue,
    cuda_visible_devices: Optional[str] = None,
):
    if cuda_visible_devices is not None:
        update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
    lib = CudaRTLibrary()
    for j in batch_tgt:
        lib.cudaSetDevice(j)
        handle = producer_queue.get()
        open_success = False
        try:
            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
            open_success = True
        except RuntimeError:
            # cannot error out here, because the producer process
            # is still waiting for the response.
            pass
        consumer_queue.put(open_success)
        if open_success:
            # modify the memory
            lib.cudaMemset(pointer, 2, 1024)
            lib.cudaDeviceSynchronize()
            # use two queues to simulate barrier
            producer_queue.get()
            consumer_queue.put(0)
            # check if the memory is modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        lib.cudaDeviceReset()


def can_actually_p2p(
    batch_src: Sequence[int],
    batch_tgt: Sequence[int],
) -> Sequence[bool]:
    """
    Usually, checking if P2P access is enabled can be done by
    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
    returns `True` even if P2P access is not actually possible.
    See https://github.com/vllm-project/vllm/issues/2728 and
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
    Therefore, we have to perform a real P2P access to check if it is actually
    possible.

    Note on p2p and cuda IPC:
    Usually, one process uses one GPU:
    GPU src --> cuda context src --> tensor src --> process src

    We need to combine p2p and cuda IPC, so that:
    GPU src --> cuda context src --> tensor src --> process src
                                        |shared|
    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
    That is to say, process src creates a tensor in GPU src, passes IPC handle to
    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
    tensor in process tgt will be reflected in the tensor in process src, because
    they are the same memory segment.
    It is important to note that process tgt accesses the tensor in GPU tgt, not
    GPU src. That's why we need p2p access.

    The most time-consuming part is the process creation. To avoid creating
    processes for every pair of GPUs, we use batched testing. We create two
    processes for testing all pairs of GPUs in batch. The trick is to reset
    the device after each test (which is not available in PyTorch).
    """  # noqa
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    # pass the CUDA_VISIBLE_DEVICES to the child process
    # to make sure they see the same set of GPUs

    # make sure the processes are spawned
    smp = mp.get_context("spawn")
    producer_queue = smp.Queue()
    consumer_queue = smp.Queue()
    result_queue = smp.Queue()
    p_src = smp.Process(
        target=producer,
        args=(
            batch_src,
            producer_queue,
            consumer_queue,
            result_queue,
            cuda_visible_devices,
        ),
    )
    p_tgt = smp.Process(
        target=consumer,
        args=(
            batch_tgt,
            producer_queue,
            consumer_queue,
            result_queue,
            cuda_visible_devices,
        ),
    )
    p_src.start()
    p_tgt.start()
    p_src.join()
    p_tgt.join()
    assert p_src.exitcode == 0 and p_tgt.exitcode == 0
    result: List[bool] = []
    for src, tgt in zip(batch_src, batch_tgt):
        a = result_queue.get()
        b = result_queue.get()
        if a != b:
            print(
                "Two processes do not agree on the P2P access"
                " status on %d -> %d, treat as disabled.",
                src,
                tgt,
            )
            result.append(False)
        else:
            result.append(a)
    return result


# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
# if we test it every time, it will be very slow, because we need to create
# N * N * 2 processes, where N is the world size. This is very slow.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the master process if it does not exist.
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None


def gpu_p2p_access_check(src: int, tgt: int) -> bool:
    """Check if GPU src can access GPU tgt."""

    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{src}->{tgt}"]

    is_distributed = dist.is_initialized()

    num_dev = cuda_device_count_stateless()
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))

    path = os.path.join(envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    from server.inference.distributed.parallel_state import get_world_group

    if (not is_distributed or get_world_group().local_rank == 0) and (not os.path.exists(path)):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        print("generating GPU P2P access cache in %s", path)
        cache: Dict[str, bool] = {}
        ids = list(range(num_dev))
        # batch of all pairs of GPUs
        batch_src, batch_tgt = zip(*list(product(ids, ids)))
        # NOTE: we use `subprocess` rather than `multiprocessing` here
        # because the caller might not have `if __name__ == "__main__":`,
        # in that case we cannot use spawn method in multiprocessing.
        # However, `can_actually_p2p` requires spawn method.
        # The fix is, we use `subprocess` to call the function,
        # where we have `if __name__ == "__main__":` in this file.

        # use a temporary file to store the result
        # we don't use the output of the subprocess directly,
        # because the subprocess might produce logging output
        with tempfile.NamedTemporaryFile() as output_file:
            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
            returned = subprocess.run([sys.executable, __file__], input=input_bytes, capture_output=True)
            # check if the subprocess is successful
            try:
                returned.check_returncode()
            except Exception as e:
                # wrap raised exception to provide more information
                raise RuntimeError(
                    f"Error happened when batch testing "
                    f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
                    f"{returned.stderr.decode()}"
                ) from e
            with open(output_file.name, "rb") as f:
                result = pickle.load(f)
        for _i, _j, r in zip(batch_src, batch_tgt, result):
            cache[f"{_i}->{_j}"] = r
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        get_world_group().barrier()
    print("reading GPU P2P access cache from %s", path)
    with open(path) as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{src}->{tgt}"]


__all__ = ["gpu_p2p_access_check"]

if __name__ == "__main__":
    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
    result = can_actually_p2p(batch_src, batch_tgt)
    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))
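A small illustration of how the cache-backed check above is typically consumed when deciding whether peer access (and hence the custom all-reduce) can be relied on. The helper name `all_pairs_p2p` and the loop are illustrative only; they are not part of this commit.

# Illustrative sketch: verify full pairwise P2P connectivity among the visible
# GPUs using gpu_p2p_access_check defined above (results come from the JSON cache).
def all_pairs_p2p(num_dev: int) -> bool:
    for i in range(num_dev):
        for j in range(num_dev):
            if i != j and not gpu_p2p_access_check(i, j):
                return False
    return True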
ktransformers/server/balance_serve/inference/distributed/parallel_state.py
0 → 100644
View file @
877aec85
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
import contextlib
import gc
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup

import server.envs as envs
from server.inference.platforms import current_platform
from server.utils import direct_register_custom_op, supports_custom_op


@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata.
    2. A list of tensors.
    """
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
            device = value.device.type
            metadata_list.append((key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    _groups[group.unique_name] = weakref.ref(group)


if supports_custom_op():

    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
        assert group_name in _groups, f"Group {group_name} is not found."
        group = _groups[group_name]()
        if group is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        group._all_reduce_in_place(tensor)

    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
        return

    direct_register_custom_op(
        op_name="inplace_all_reduce",
        op_func=inplace_all_reduce,
        mutates_args=["tensor"],
        fake_impl=inplace_all_reduce_fake,
    )

    def outplace_all_reduce(
        tensor: torch.Tensor,
        group_name: str,
        bsz_tensor: torch.Tensor,
        is_compute_bound: bool = False,
        overlap: bool = False,
    ) -> torch.Tensor:
        assert group_name in _groups, f"Group {group_name} is not found."
        group = _groups[group_name]()
        if group is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        return group._all_reduce_out_place(tensor, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)

    def outplace_all_reduce_fake(
        tensor: torch.Tensor,
        group_name: str,
        bsz_tensor: torch.Tensor,
        is_compute_bound: bool = False,
        overlap: bool = False,
    ) -> torch.Tensor:
        return torch.empty_like(tensor)

    direct_register_custom_op(
        op_name="outplace_all_reduce",
        op_func=outplace_all_reduce,
        mutates_args=[],
        fake_impl=outplace_all_reduce_fake,
    )
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
    e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
    the processes in the group. It can route the communication to
    a specific implementation (e.g. switch allreduce implementation
    based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |  0   |  0   |     0      |       0
    #   1     |  0   |  1   |     1      |       1
    #   2     |  1   |  2   |     0      |       2
    #   3     |  1   |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None
        assert current_platform.is_cuda_alike()

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator

        # lazy import to avoid documentation build error
        from server.inference.distributed.custom_all_reduce import CustomAllreduce
        from server.inference.distributed.pynccl import PyNcclCommunicator

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        # if use_pynccl and self.world_size > 1:
        #     self.pynccl_comm = PyNcclCommunicator(
        #         group=self.cpu_group,
        #         device=self.device,
        #     )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

        #### we assume we won't use tpu or hpu or xpu or messagequeue broadcast
        # from vllm.distributed.device_communicators.tpu_communicator import (
        #     TpuCommunicator)
        # self.tpu_communicator: Optional[TpuCommunicator] = None
        # if use_tpu_communicator and self.world_size > 1:
        #     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
        self.tpu_communicator = None

        # from vllm.distributed.device_communicators.hpu_communicator import (
        #     HpuCommunicator)
        # self.hpu_communicator: Optional[HpuCommunicator]
        # if use_hpu_communicator and self.world_size > 1:
        #     self.hpu_communicator = HpuCommunicator(group=self.device_group)
        self.hpu_communicator = None

        # from vllm.distributed.device_communicators.xpu_communicator import (
        #     XpuCommunicator)
        # self.xpu_communicator: Optional[XpuCommunicator]
        # if use_xpu_communicator and self.world_size > 1:
        #     self.xpu_communicator = XpuCommunicator(group=self.device_group)
        self.xpu_communicator = None

        # from vllm.distributed.device_communicators.shm_broadcast import (
        #     MessageQueue)
        # self.mq_broadcaster: Optional[MessageQueue] = None
        # if use_message_queue_broadcaster and self.world_size > 1:
        #     self.mq_broadcaster = MessageQueue.create_from_process_group(
        #         self.cpu_group, 1 << 22, 6)
        self.mq_broadcaster = None

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: Optional[GraphCaptureContext] = None):
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        ca_comm = self.ca_comm
        maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph  |
            # --------------------------------------------
            # custom allreduce       | enabled | enabled |
            # PyNccl                 | disabled| enabled |
            # torch.distributed      | enabled | disabled|
            #
            # Note that custom allreduce will have a runtime check, if the
            # tensor size is too large, it will fallback to the next
            # available option.
            # In summary: When using CUDA graph, we use
            # either custom all-reduce kernel or pynccl. When not using
            # CUDA graph, we use either custom all-reduce kernel or
            # PyTorch NCCL. We always prioritize using custom all-reduce
            # kernel but fall back to PyTorch or pynccl if it is
            # disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context

    def all_reduce(
        self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False
    ) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
        group name as a string, and then look up the group coordinator from
        the group name, dispatch the all-reduce operation to the group
        coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we need to figure out if the op is
        in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex

            ipex.distributed.all_reduce(input_, group=self.device_group)
            return input_

        if not supports_custom_op():
            self._all_reduce_in_place(input_)
            return input_

        if self.tpu_communicator is not None and not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
            return self.tpu_communicator.all_reduce(input_)

        if self.hpu_communicator is not None and not self.hpu_communicator.disabled:
            return self.hpu_communicator.all_reduce(input_)

        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
            return self.xpu_communicator.all_reduce(input_)

        if self.ca_comm is not None and not self.ca_comm.disabled and self.ca_comm.should_custom_ar(input_):
            return torch.ops.vllm.outplace_all_reduce(
                input_,
                group_name=self.unique_name,
                bsz_tensor=bsz_tensor,
                is_compute_bound=is_compute_bound,
                overlap=overlap,
            )
        else:
            # assert self.ca_comm is not None
            # assert not self.ca_comm.disabled
            # assert self.ca_comm.should_custom_ar(input_)
            torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name)
            return input_

    def _all_reduce_out_place(
        self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False
    ) -> torch.Tensor:
        ca_comm = self.ca_comm
        assert ca_comm is not None
        assert not ca_comm.disabled
        out = ca_comm.custom_all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)
        assert out is not None
        return out

    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.all_reduce(input_)
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert (
            -input_.dim() <= dim < input_.dim()
        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

        # For HPUs, use HPU communicator.
        hpu_comm = self.hpu_communicator
        if hpu_comm is not None and not hpu_comm.disabled:
            return hpu_comm.all_gather(input_, dim)

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size, dtype=input_.dtype, device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor

    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert (
            -input_.dim() <= dim < input_.dim()
        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
            return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim)
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_, gather_list, dst=self.ranks[dst], group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_, src=self.ranks[src], group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj], src=self.ranks[src], group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv, src=self.ranks[src], group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list, src=self.ranks[src], group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank."
        )

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long, device="cpu")

        # Send object size
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

        assert (
            src != self.rank_in_group
        ), "Invalid source rank. Source rank is the same as the current rank."

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor, src=self.ranks[src], group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu",
        )

        rank_object = torch.distributed.recv(object_tensor, src=self.ranks[src], group=self.cpu_group)

        assert (
            rank_object == rank_size
        ), "Received object sender rank does not match the size sender rank."

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}"
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor, src=self.ranks[src], group=group, async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True,
                        )
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(tensor, src=self.ranks[src], group=group, async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = 0 if all_gather_group is None else all_gather_group.rank_in_group

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(tensor_dict, dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = 0 if all_gather_group is None else all_gather_group.rank_in_group

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = all_gather_group is not None and tensor.numel() % all_gather_size == 0

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor, src=self.ranks[src], group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    assert _WORLD is not None, "world group is not initialized"
    return _WORLD


def init_world_group(ranks: List[int], local_rank: int, backend: str) -> GroupCoordinator:
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        group_name="world",
    )


def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    assert _TP is not None, "tensor model parallel group is not initialized"
    return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, "pipeline model parallel group is not initialized"
    return _PP


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group


@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched in the background on the default stream.
    """
    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(context):
        yield context


_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    print(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert (
            _WORLD.world_size == torch.distributed.get_world_size()
        ), "world group already initialized with a different world size"


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    global _TP
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="pp",
    )


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend)
        return

    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
        "tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}"
    )
    pp_world_size = get_pp_group().world_size
    assert pp_world_size == pipeline_model_parallel_size, (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}"
    )


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return _TP is not None and _PP is not None


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group


def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()


def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).
    """
    assert (
        torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL
    ), "in_the_same_node_as should be tested with a non-NCCL group."
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                torch.distributed.broadcast_object_list([shm.name], src=ranks[source_rank], group=pg)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(recv, src=ranks[source_rank], group=pg)
                name = recv[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch(
                    "multiprocessing.resource_tracker.register",
                    lambda *args, **kwargs: None,
                ):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[: len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        print("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return [x == 1 for x in is_in_the_same_node.tolist()]
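The module docstring at the top of this file describes the intended call order. A condensed sketch of that workflow follows; the function names `setup_distributed`/`teardown_distributed` and the size arguments are placeholders, and in the balance_serve backend these values come from the launcher and config rather than being hard-coded.

# Sketch of the workflow described in the module docstring above
# (placeholder sizes; MASTER_ADDR/MASTER_PORT are assumed to be set for "env://").
def setup_distributed(rank: int, world_size: int, tp_size: int, pp_size: int):
    init_distributed_environment(world_size=world_size, rank=rank, local_rank=rank)
    ensure_model_parallel_initialized(tp_size, pp_size)
    # the returned GroupCoordinator routes allreduce to the custom kernel or NCCL
    return get_tp_group()

def teardown_distributed():
    destroy_model_parallel()
    destroy_distributed_environment()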
ktransformers/server/balance_serve/inference/distributed/pynccl.py
0 → 100644
View file @
877aec85
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Union
# ===================== import region =====================
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
server.inference.distributed.pynccl_wrapper
import
(
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
,
ncclRedOpTypeEnum
,
ncclUniqueId
,
)
from
server.inference.distributed.utils
import
StatelessProcessGroup
class
PyNcclCommunicator
:
def
__init__
(
self
,
group
:
Union
[
ProcessGroup
,
StatelessProcessGroup
],
device
:
Union
[
int
,
str
,
torch
.
device
],
library_path
:
Optional
[
str
]
=
None
,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be bind to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
if
not
isinstance
(
group
,
StatelessProcessGroup
):
assert
dist
.
is_initialized
()
assert
(
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
),
"PyNcclCommunicator should be attached to a non-NCCL group."
# note: this rank is the rank in the group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
else
:
self
.
rank
=
group
.
rank
self
.
world_size
=
group
.
world_size
self
.
group
=
group
# if world_size == 1, no need to create communicator
if
self
.
world_size
==
1
:
self
.
available
=
False
self
.
disabled
=
True
self
.
stream
=
None
return
try
:
self
.
nccl
=
NCCLLibrary
(
library_path
)
except
Exception
:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self
.
available
=
False
self
.
disabled
=
True
self
.
stream
=
None
return
self
.
available
=
True
self
.
disabled
=
False
print
(
"vLLM is using nccl==%s"
,
self
.
nccl
.
ncclGetVersion
())
if
self
.
rank
==
0
:
# get the unique id from NCCL
self
.
unique_id
=
self
.
nccl
.
ncclGetUniqueId
()
else
:
# construct an empty unique id
self
.
unique_id
=
ncclUniqueId
()
if
not
isinstance
(
group
,
StatelessProcessGroup
):
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
ranks
=
dist
.
get_process_group_ranks
(
group
)
# arg `src` in `broadcast` is the global rank
dist
.
broadcast
(
tensor
,
src
=
ranks
[
0
],
group
=
group
)
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
else
:
self
.
unique_id
=
group
.
broadcast_obj
(
self
.
unique_id
,
src
=
0
)
if
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with
torch
.
cuda
.
device
(
device
):
self
.
comm
:
ncclComm_t
=
self
.
nccl
.
ncclCommInitRank
(
self
.
world_size
,
self
.
unique_id
,
self
.
rank
)
self
.
stream
=
torch
.
cuda
.
Stream
()
# A small all_reduce for warmup.
data
=
torch
.
zeros
(
1
,
device
=
device
)
self
.
all_reduce
(
data
)
self
.
stream
.
synchronize
()
del
data
# by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually
# when we are using CUDA graph.
self
.
disabled
=
True

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()),
                                tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to change the state of the communicator.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        yield

        self.disabled = old_disable
        self.stream = old_stream
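A minimal usage sketch for PyNcclCommunicator (illustrative only: it assumes torch.distributed has already been initialized with a non-NCCL backend such as gloo, at least two ranks, one GPU per rank, and an NCCL shared library that NCCLLibrary can load; the group and device choices are assumptions, not part of this file):

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")   # non-NCCL group, as the assert above requires
rank = dist.get_rank()
comm = PyNcclCommunicator(dist.group.WORLD, device=rank)  # one GPU per rank assumed

x = torch.ones(16, device=f"cuda:{rank}")
# the communicator starts disabled; enable it explicitly, e.g. around
# CUDA-graph capture or decode steps
with comm.change_state(enable=True):
    comm.all_reduce(x)            # in-place sum across ranks, issued on comm.stream
comm.stream.synchronize()         # wait for the result before reading x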
ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py
0 → 100644
View file @
877aec85
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
# 1. We tried to use `cupy`; it calls NCCL correctly, but `cupy` itself
#    often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
#    may invoke many other CUDA APIs that are not allowed while capturing
#    a CUDA graph. For further details, please check
#    https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this: it requires
# recompiling the code every time we want to switch between different
# versions. The current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
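# For example (illustrative only; the path below is an assumption about where a
# system NCCL build might live, not something this file requires):
#   VLLM_NCCL_SO_PATH=/usr/lib/x86_64-linux-gnu/libnccl.so.2 python your_server.py
# or pass the path explicitly in Python: NCCLLibrary("/path/to/libnccl.so.2").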
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from torch.distributed import ReduceOp

from server.utils import find_nccl_library

# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p


class ncclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]


cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.int32:
            return cls.ncclInt32
        if dtype == torch.int64:
            return cls.ncclInt64
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")
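# For example, following the table above (which mirrors nccl.h.in):
#   ncclDataTypeEnum.from_torch(torch.float16)  -> 6  (ncclFloat16)
#   ncclDataTypeEnum.from_torch(torch.bfloat16) -> 9  (ncclBfloat16)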


ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        if op == ReduceOp.SUM:
            return cls.ncclSum
        if op == ReduceOp.PRODUCT:
            return cls.ncclProd
        if op == ReduceOp.MAX:
            return cls.ncclMax
        if op == ReduceOp.MIN:
            return cls.ncclMin
        if op == ReduceOp.AVG:
            return cls.ncclAvg
        raise ValueError(f"Unsupported op: {op}")
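# e.g. ncclRedOpTypeEnum.from_torch(ReduceOp.SUM) -> 0, the ncclSum value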


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


class NCCLLibrary:
    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t,
                 [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t ncclCommInitRank(
        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("ncclCommInitRank", ncclResult_t,
                 [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
                  ctypes.c_int]),
        # ncclResult_t ncclAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllReduce", ncclResult_t,
                 [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
                  ncclRedOp_t, ncclComm_t, cudaStream_t]),
        # ncclResult_t ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
        Function("ncclSend", ncclResult_t,
                 [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
                  ncclComm_t, cudaStream_t]),
        # ncclResult_t ncclRecv(
        #   void* recvbuff, size_t count, ncclDataType_t datatype,
        #   int src, ncclComm_t comm, cudaStream_t stream);
        Function("ncclRecv", ncclResult_t,
                 [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
                  ncclComm_t, cudaStream_t]),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):

        so_file = so_file or find_nccl_library()

        try:
            if so_file not in NCCLLibrary.path_to_dict_mapping:
                lib = ctypes.CDLL(so_file)
                NCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = NCCLLibrary.path_to_library_cache[so_file]
        except Exception as e:
            print(
                "Failed to load NCCL library from %s. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted, "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable VLLM_NCCL_SO_PATH "
                "to point to the correct nccl library path."
                % (so_file, platform.platform()))
            raise e

        if so_file not in NCCLLibrary.path_to_dict_mapping:
            _funcs: Dict[str, Any] = {}
            for func in NCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        return self._funcs["ncclGetErrorString"](result).decode("utf-8")

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    def ncclGetVersion(self) -> str:
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        version_str = str(version.value)
        # something like 21903 --> "2.19.3"
        major = version_str[0].lstrip("0")
        minor = version_str[1:3].lstrip("0")
        patch = version_str[3:].lstrip("0")
        return f"{major}.{minor}.{patch}"

    def ncclGetUniqueId(self) -> ncclUniqueId:
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        comm = ncclComm_t()
        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                        world_size, unique_id,
                                                        rank))
        return comm

    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                                dest, comm, stream))

    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                comm, stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))


__all__ = [
    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
    "ncclComm_t", "cudaStream_t", "buffer_type"
]
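As a standalone sanity check of the wrapper itself (a sketch, assuming an NCCL shared library is discoverable through find_nccl_library() or the VLLM_NCCL_SO_PATH environment variable; the explicit .so path and the import form are illustrative and depend on how the package is installed):

from pynccl_wrapper import NCCLLibrary, ncclUniqueId  # import path depends on your packaging

nccl = NCCLLibrary()                      # or NCCLLibrary("/path/to/libnccl.so.2")
print(nccl.ncclGetVersion())              # e.g. "2.19.3"
uid = nccl.ncclGetUniqueId()              # rank 0 would share this 128-byte id with other ranks
assert isinstance(uid, ncclUniqueId)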