Unverified Commit a41d2163 authored by wang jiahao's avatar wang jiahao Committed by GitHub
Browse files

Merge pull request #1013 from kvcache-ai/work-concurrent

In v0.2.4 version, we’ve added highly desired multi-concurrency support to the community through a major refactor of the whole architecture.
parents f142f4df 4ed9744e
......@@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
......@@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
......@@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.RMSNorm
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
replace:
class: ktransformers.operators.mlp.kDeepseekV3MLP
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
\ No newline at end of file
......@@ -38,7 +38,7 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context(prefill may slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.RMSNorm
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
replace:
class: ktransformers.operators.mlp.kDeepseekV3MLP
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV4
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
\ No newline at end of file
import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
from ktransformers.util.utils import get_free_ports
class ArgumentParser:
def __init__(self, cfg):
......@@ -16,20 +16,18 @@ class ArgumentParser:
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
parser.add_argument("--model_dir", type=str)
parser.add_argument("--model_path", type=str)
parser.add_argument("--model_path", type=str, default=self.cfg.model_path)
parser.add_argument(
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_prefill_size", type=int, default=8192)
parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size)
# model configs
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
......@@ -62,7 +60,6 @@ class ArgumentParser:
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
......@@ -103,6 +100,18 @@ class ArgumentParser:
# local chat
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
# async server
parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy)
# parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port)
# parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port)
# parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port)
parser.add_argument("--page_size", type=str, default=self.cfg.page_size)
parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only)
parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage)
parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB)
args = parser.parse_args()
if (args.model_dir is not None or args.model_path is not None):
if (args.model_path is not None):
......@@ -123,6 +132,15 @@ class ArgumentParser:
self.cfg.mount_web = args.web
self.cfg.server_ip = args.host
self.cfg.server_port = args.port
self.cfg.backend_type = args.type
self.cfg.user_force_think = args.force_think
args.gpu_memory_size = args.cache_lens*2*576*61
self.cfg.gpu_memory_size = args.gpu_memory_size
free_ports = get_free_ports(3, [args.port])
args.sched_port = free_ports[0]
args.sched_metrics_port = free_ports[1]
args.kvc2_metrics_port = free_ports[2]
self.cfg.sched_port = free_ports[0]
self.cfg.sched_metrics_port = free_ports[1]
self.cfg.kvc2_metrics_port = free_ports[2]
return args
......@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
class Config:
protected_namespaces = ()
paged: bool = Field(None, description="Whether to use paged attention kv cache")
total_context: int = Field(
None,
description=(
"Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
" total to distribute dynamically over however many jobs are active at once"
),
)
max_batch_size: int = Field(
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
)
chunk_prefill_size: int = Field(
chunk_size: int = Field(
None,
description=(
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
......@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
......
......@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThr
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
class ThreadContextManager:
lock: Lock
threads_context: Dict[ObjectID, ThreadContext]
......@@ -36,7 +38,16 @@ class ThreadContextManager:
elif isinstance(self.interface, TransformersInterface):
new_context = TransformersThreadContext(run, self.interface)
else:
raise NotImplementedError
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
if isinstance(self.interface, BalanceServeInterface):
new_context = BalanceServeThreadContext(run, self.interface)
else:
raise NotImplementedError
# elif isinstance(self.interface, BalanceServeInterface):
# new_context = BalanceServeThreadContext(run, self.interface)
# else:
# raise NotImplementedError
self.threads_context[run.thread_id] = new_context
# self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
re = self.threads_context[run.thread_id]
......
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
AutoTokenizer,
AutoConfig,
GenerationConfig,
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
ConfigArgs,
default_args,
TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os
ktransformer_rules_dir = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
streamer = TextStreamer(tokenizer)
while True:
token = await queue.get()
#print(f"Got token: {token}")
if token is None:
# str = f'{token}\n\n'
# str = model.tokenizer.decode(token)
s = streamer.end()
if s is not None:
yield s
break
# str = model.tokenizer.decode(token)
yield streamer.put(token)
def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
#print(len(query_updates), generated_tokens.size(0), generated_tokens)
for i in range(generated_tokens.size(0)):
print(generated_tokens[i].item())
query_updates[i].generated_token = generated_tokens[i].item()
if not query_manager.query_map[query_updates[i].id].is_prefill:
pos = query_updates[i].active_position
query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
def report_last_time_performance(profiler: Profiler):
try:
tokenize_time = profiler.get_timer_sec('tokenize')
prefill_time = profiler.get_timer_sec('prefill')
decode_time = profiler.get_timer_sec('decode')
prefill_count = profiler.get_counter('prefill')
decode_count = profiler.get_counter('decode')
logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
except:
logger.info(f'Performance statistics not recorded')
class Engine:
sched_client : SchedulerClient
updates : list[sched_ext.QueryUpdate]
batch : sched_ext.BatchQueryTodo
model_runner: ModelRunner
sampler: Sampler
query_manager: QueryManager
cache: KDeepSeekV3Cache
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
self.args = args
# 子进程和父进程无法共享 config 变量
for key, value in vars(args).items():
if value is not None and hasattr(Config(), key):
setattr(Config(), key, value)
self.device = self.args.device
self.sched_client = SchedulerClient(args.sched_port)
self.updates = []
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
self.gen_queue = generated_token_queue
print(f"Getting inference context from sched_client.")
inference_context = self.sched_client.get_inference_context_raw()
print(f"Got inference context, sending it to subscribers.")
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
self.cache.load(inference_context)
print(f"kv_cache loaded successfully.")
self.block_num = inference_context.k_cache[0].size(1)
with torch.device("meta"):
if config.architectures[0] == "DeepseekV3ForCausalLM":
self.model = KDeepseekV3ForCausalLM(config, self.cache)
elif config.architectures[0] == "DeepseekV2ForCausalLM":
self.model = KDeepseekV2ForCausalLM(config, self.cache)
# print(self.block_num)
context = zmq.Context()
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
# time.sleep(1) # make sure all subscribers are ready
try:
generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
generation_config = GenerationConfig(
max_length=args.max_new_tokens,
temperature=args.temperature,
top_p=args.top_p,
do_sample=True
)
if args.optimize_config_path is None:
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_config_path = args.optimize_config_path
gguf_path = args.gguf_path
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.model.eval()
#@TODO add config
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
self.sampler = Sampler()
self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
def sampling(self, forward_output: ForwardBatchOutput):
generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
for i in range(forward_output.num_batchs):
logit = forward_output.logits[i]
if hasattr(forward_output, "temperatures"):
temperatures = forward_output.temperatures[i]
else:
temperatures = None
if hasattr(forward_output, "top_ps"):
top_ps = forward_output.top_ps[i]
else:
top_ps = None
sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
generated_tokens, probs=self.sampler(logit, sample_options)
return generated_tokens, probs
def loop(self):
next_batch = None
while True:
self.batch = next_batch
if self.batch is not None:
self.model_runner.run(self.batch, self.query_manager)
if len(self.updates) > 0:
for q in self.updates:
if q.is_prefill == True:
continue
# print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
try:
self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
except queue.Full:
pass#print("Queue is full after timeout; unable to put more items.")
next_batch = self.sched_client.update_last_batch(self.updates)
if next_batch.query_ids == []:
next_batch = None
self.pub_socket.send_pyobj(next_batch)
if next_batch is not None:
self.query_manager.add_query(next_batch)
if self.batch is not None:
self.model_runner.sync()
print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
# if self.rank == 0:
generated_tokens, probs = self.sampling( self.model_runner.output)
self.updates = self.query_manager.update(self.batch)
fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
else:
self.updates = []
class BalanceServeThreadContext(ThreadContext):
def get_local_messages(self):
local_messages = []
for m in self.messages:
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
return local_messages
def run_engine(args, token_queue, broadcast_endpoint, event):
engine = Engine(args, token_queue, broadcast_endpoint)
if args.use_cuda_graph:
engine.model_runner.warmup()
event.set()
engine.loop()
class BalanceServeInterface(BackendInterfaceBase):
use_static_cache: bool = True
model: Any
tokenizer: AutoTokenizer
cache: StaticCache
generated_ids: torch.Tensor
seq_length: int
streamer: TextStreamer
# thread_related
last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set()
def __init__(self, args: ConfigArgs = default_args):
self.args = args
self.queue_map:dict[int,asyncio.Queue] = {}
self.thread_map: dict[int, int] = {}
processes = []
self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config
ctx = mp.get_context("spawn")
self.token_queue = ctx.Queue(maxsize=1000)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
self.sched_client = SchedulerClient(args.sched_port)
self.streamer = TextStreamer(self.tokenizer)
start_event = ctx.Event()
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
p.start()
processes.append(p)
start_event.wait()
def run_queue_proxy(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.queue_proxy())
@asynccontextmanager
async def lifespan(self, app: FastAPI):
asyncio.create_task(self.queue_proxy())
yield
async def queue_proxy(self):
print("Queue Proxy Started")
while True:
try:
query_id, token = self.token_queue.get_nowait()
try:
# query id might not be allocated yet
self.queue_map[query_id].put_nowait(token)
#print(f"Proxy Put token: {token} to queue for query id: {query_id}")
except asyncio.QueueFull:
#print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
await self.queue_map[query_id].put(token)
except queue.Empty:
# print("no new token")
# await asyncio.sleep(1)
await asyncio.sleep(0)
def tokenize_prompt(self, prompt: str):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
return input_ids
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
for m in messages:
if m["role"] == "system":
logger.warning(f'change {m["role"]} to user')
m["role"] = "user"
new_messages = [messages[0]]
for m in messages[1:]:
if m["role"] == "user" and new_messages[-1]["role"] == "user":
logger.warning("merge two adjacent user messages")
new_messages[-1]["content"] += '\n' + m["content"]
else:
new_messages.append(m)
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
profiler = Profiler()
profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
else:
raise ValueError("local_messages should be List or str")
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
profiler.pause_timer("tokenize")
profiler.create_and_start_timer("prefill")
query_add = sched_ext.QueryAdd()
query_add.query_token = input_ids[0].tolist()
query_length = input_ids[0].shape[0]
query_add.query_length = query_length
profiler.set_counter("prefill", query_length)
#@TODO add server
stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
query_add.stop_criteria = stop_criteria
query_add.sample_options.temperature = temperature
if top_p == 0:
top_p = 0.0001
query_add.sample_options.top_p = top_p
query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
query_id = self.sched_client.add_query(query_add)
queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
self.queue_map[query_id] = queue
self.thread_map[thread_id] = query_id
is_first_token = True
async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
if is_first_token:
is_first_token=False
profiler.pause_timer("prefill")
profiler.create_and_start_timer("decode")
profiler.set_counter("decode", 0)
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think, None
else:
profiler.inc("decode")
yield token, None
profiler.pause_timer("decode")
report_last_time_performance(profiler)
yield self.streamer.end(), None
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
yield "", "length"
else:
yield "", "stop"
yield RawUsage(
tokenize_time = profiler.get_timer_sec('tokenize'),
prefill_time = profiler.get_timer_sec('prefill'),
decode_time = profiler.get_timer_sec('decode'),
prefill_count = profiler.get_counter('prefill'),
decode_count = profiler.get_counter('decode'),
)
......@@ -211,11 +211,11 @@ class KTransformersInterface(TransformersInterface):
chunk_start = 0
while chunk_start < input_ids_length:
chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)
if self.cache != None:
self.cache.cur_idx=cache_position[chunk_start:chunk_end]
logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
chunk_start += self.args.chunk_prefill_size
chunk_start += self.args.chunk_size
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
......
......@@ -208,6 +208,8 @@ class TransformersInterface(BackendInterfaceBase):
temperature = self.model.generation_config.temperature
if top_p is None:
top_p = self.model.generation_config.top_p
if top_p == 0:
top_p = 0.0001
generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
......@@ -341,7 +343,7 @@ class TransformersInterface(BackendInterfaceBase):
for i in range(1, self.max_new_tokens):
with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
......
'''
Date: 2024-11-07 07:30:16
LastEditors: djw
LastEditTime: 2024-11-15 14:23:26
'''
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
import yaml
import json
from typing import Optional
class ModelConfig:
vocab_size: int = 32000
n_layer: int = 1
n_head: int = 32
dim: int = 4096
intermediate_size: int = 18944
n_local_heads: int = 8
head_dim: int = 128
rope_base: float = 1000000.0
norm_eps: float = 1e-06
rope_scaling: Optional[dict] = None
rms_norm_eps: float = 1e-6
hidden_act: str = "silu"
model_path: str
gguf_path: str
optimize_rule_path: str
speculative_rule_path: str
# quantize config
quant_algorithm: Optional[str] = None
quant_group_size: Optional[int] = None
quant_num_bits: Optional[int] = None
json_key_map = {
"vocab_size": "vocab_size",
"n_layer": "num_hidden_layers",
"n_head": "num_attention_heads",
"dim": "hidden_size",
"intermediate_size": "intermediate_size",
"n_local_heads": "num_key_value_heads",
"rope_base": "rope_theta",
"norm_eps": "norm_eps",
"rms_norm_eps": "rms_norm_eps",
"hidden_act": "hidden_act",
}
def __init__(self, config):
self.model_path = config["model"]["model_path"]
self.gguf_path = config["model"]["gguf_path"]
self.optimize_rule_path = config["model"]["optimize_rule_path"]
if "speculative_rule_path" in config["model"]:
self.speculative_rule_path = config["model"]["speculative_rule_path"]
self.speculative_gguf_path = config["model"]["speculative_gguf_path"]
self.speculative_model_path = config["model"]["speculative_model_path"]
self.quant_algorithm = config["model"]["quant"]["algorithm"]
self.quant_group_size = config["model"]["quant"]["group_size"]
self.quant_num_bits = config["model"]["quant"]["num_bits"]
self.load_config()
self.n_layer = config["model"]["n_layers"]
def load_config(self):
config_file = f"{self.model_path}/config.json"
try:
with open(config_file, "r") as f:
config_data = json.load(f)
except FileNotFoundError:
raise FileNotFoundError(f"Configuration file not found at {config_file}")
for attr, json_key in self.json_key_map.items():
if json_key in config_data:
setattr(self, attr, config_data[json_key])
else:
setattr(self, attr, getattr(self, attr))
class ParallelConfig:
def __init__(
self,
config,
) -> None:
self.pipeline_parallel_size = config["parallel"]["pp"]
self.tensor_parallel_size = config["parallel"]["tp"]
self.disable_custom_all_reduce = config["parallel"]["disable_custom_all_reduce"]
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
class AttnConfig:
page_size: int = 256
block_num: int = 32
max_batch_token : int = 256
max_batch_size: int = 32
def __init__(self, config):
self.page_size = config["attn"]["page_size"]
self.block_num = config["attn"]["block_num"]
self.max_batch_token = config["attn"]["max_batch_token"]
self.max_batch_size = config["attn"]["max_batch_size"]
class SamplerConfig():
# Batched sampling params
temperatures: float
is_all_greedy: bool
def __init__(self, config):
self.temperatures = config["sample"]["temperature"]
self.is_all_greedy = True
def load_yaml_config(file_path):
with open(file_path, "r") as f:
return yaml.safe_load(f)
class LLMConfig:
model_config: ModelConfig
parallel_config: ParallelConfig
attn_config: AttnConfig
sample_config: SamplerConfig
config_file: str
def __init__(self, config_file):
self.config_file = config_file
config = load_yaml_config(config_file)
self.model_config = ModelConfig(config)
self.parallel_config = ParallelConfig(config)
self.attn_config = AttnConfig(config)
self.sample_config = SamplerConfig(config)
from .communication_op import *
from .parallel_state import *
from .utils import *
"""
Date: 2024-12-11 06:02:42
LastEditors: djw
LastEditTime: 2024-12-12 09:52:06
"""
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)
def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
def tensor_model_parallel_gather(
input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of the
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(lib_name), \
f"Unexpected filename: {filename} for library {lib_name}"
return path
class CudaRTLibrary:
exported_functions = [
# ​cudaError_t cudaSetDevice ( int device )
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
# cudaError_t cudaDeviceSynchronize ( void )
Function("cudaDeviceSynchronize", cudaError_t, []),
# ​cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# ​cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
# ​cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function("cudaIpcOpenMemHandle", cudaError_t, [
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
assert so_file is not None, \
"libcudart is not loaded in the current process"
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")
def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
count: int) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
def cudaIpcGetMemHandle(self,
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
ctypes.byref(handle), devPtr))
return handle
def cudaIpcOpenMemHandle(self,
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
return devPtr
import ctypes
from contextlib import contextmanager
from typing import List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import server.envs as envs
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
from server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check
from server.inference.distributed.parallel_state import in_the_same_node_as
from server.inference.platforms import current_platform
from server.utils import cuda_device_count_stateless
import vLLMCustomAllreduce
try:
vLLMCustomAllreduce.meta_size()
custom_ar = True
except Exception:
# For AMD GPUs and CPUs
custom_ar = False
def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
continue
if envs.VLLM_SKIP_P2P_CHECK:
print("Skipping P2P check and trusting the driver's P2P report.")
return torch.cuda.can_device_access_peer(rank, i)
if not gpu_p2p_access_check(rank, i):
return False
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size
def __init__(
self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 1024,
) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
self.group = group
assert (
dist.get_backend(group) != dist.Backend.NCCL
), "CustomAllreduce should be attached to a non-NCCL group."
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
print(
"Custom allreduce is disabled because this process group"
" spans across nodes."
)
return
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
print(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size,
str(CustomAllreduce._SUPPORTED_WORLD_SIZES),
)
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
assert current_platform.is_cuda()
from server.inference.platforms.cuda import CudaPlatform
cuda_platform: CudaPlatform = current_platform
full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
print(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
print(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self.disabled = False
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(
vLLMCustomAllreduce.meta_size() + max_size, group=group
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(
8 * 1024 * 1024, dtype=torch.uint8, device=self.device
)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.full_nvlink = full_nvlink
self._ptr = vLLMCustomAllreduce.init_custom_ar(
self.meta_ptrs, self.rank_data, rank, self.full_nvlink
)
vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs)
@staticmethod
def create_shared_buffer(
size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer.value) # type: ignore
else:
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
return pointers
@staticmethod
def free_shared_buffer(
pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr)
print("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
def all_reduce(
self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: torch.Tensor = None, registered: bool = False,
is_compute_bound=False, overlap=False
):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if is_compute_bound:
sms = 2 if overlap else 36
else:
sms = 20 if overlap else 36
#print("all reduce sms", sms)
if out is None:
out = torch.empty_like(inp)
if registered:
vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms)
else:
vLLMCustomAllreduce.all_reduce(
self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms
)
return out
def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap)
def close(self):
if not self.disabled and self._ptr:
vLLMCustomAllreduce.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
def __del__(self):
self.close()
import ctypes
import json
import os
import pickle
import subprocess
import sys
import tempfile
from itertools import product
from typing import Dict, List, Optional, Sequence
import torch.distributed as dist
import torch.multiprocessing as mp
import server.envs as envs
from server.inference.distributed.cuda_wrapper import CudaRTLibrary
from server.utils import cuda_device_count_stateless, update_environment_variables
def producer(
batch_src: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
):
if cuda_visible_devices is not None:
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for i in batch_src:
lib.cudaSetDevice(i)
pointer = lib.cudaMalloc(1024)
lib.cudaMemset(pointer, 1, 1024)
lib.cudaDeviceSynchronize()
handle = lib.cudaIpcGetMemHandle(pointer)
producer_queue.put(handle)
open_success = consumer_queue.get()
if open_success:
# use two queues to simulate barrier
producer_queue.put(0)
consumer_queue.get()
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def consumer(
batch_tgt: Sequence[int],
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices: Optional[str] = None,
):
if cuda_visible_devices is not None:
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
lib = CudaRTLibrary()
for j in batch_tgt:
lib.cudaSetDevice(j)
handle = producer_queue.get()
open_success = False
try:
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
open_success = True
except RuntimeError:
# cannot error out here, because the producer process
# is still waiting for the response.
pass
consumer_queue.put(open_success)
if open_success:
# modify the memory
lib.cudaMemset(pointer, 2, 1024)
lib.cudaDeviceSynchronize()
# use two queues to simulate barrier
producer_queue.get()
consumer_queue.put(0)
# check if the memory is modified
host_data = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
for i in range(1024):
if ord(host_data[i]) != 2:
open_success = False
break
result_queue.put(open_success)
lib.cudaDeviceReset()
def can_actually_p2p(
batch_src: Sequence[int],
batch_tgt: Sequence[int],
) -> Sequence[bool]:
"""
Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
returns `True` even if P2P access is not actually possible.
See https://github.com/vllm-project/vllm/issues/2728 and
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
Therefore, we have to perform a real P2P access to check if it is actually
possible.
Note on p2p and cuda IPC:
Usually, one process uses one GPU:
GPU src --> cuda context src --> tensor src --> process src
We need to combine p2p and cuda IPC, so that:
GPU src --> cuda context src --> tensor src --> process src
|shared|
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
That is to say, process src creates a tensor in GPU src, passes IPC handle to
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
tensor in process tgt will be reflected in the tensor in process src, because
they are the same memory segment.
It is important to note that process tgt accesses the tensor in GPU tgt, not
GPU src. That's why we need p2p access.
The most time-consuming part is the process creation. To avoid creating
processes for every pair of GPUs, we use batched testing. We create two
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
""" # noqa
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs
# make sure the processes are spawned
smp = mp.get_context("spawn")
producer_queue = smp.Queue()
consumer_queue = smp.Queue()
result_queue = smp.Queue()
p_src = smp.Process(
target=producer,
args=(
batch_src,
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices,
),
)
p_tgt = smp.Process(
target=consumer,
args=(
batch_tgt,
producer_queue,
consumer_queue,
result_queue,
cuda_visible_devices,
),
)
p_src.start()
p_tgt.start()
p_src.join()
p_tgt.join()
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
result: List[bool] = []
for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get()
b = result_queue.get()
if a != b:
print(
"Two processes do not agree on the P2P access"
" status on %d -> %d, treat as disabled.",
src,
tgt,
)
result.append(False)
else:
result.append(a)
return result
# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs,across processes.
# if we test it every time, it will be very slow, because we need to create
# N * N * 2 processes, where N is the world size. This is very slow.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the master process if it does not exist.
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
"""Check if GPU src can access GPU tgt."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
is_distributed = dist.is_initialized()
num_dev = cuda_device_count_stateless()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
path = os.path.join(
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
)
os.makedirs(os.path.dirname(path), exist_ok=True)
from server.inference.distributed.parallel_state import get_world_group
if (not is_distributed or get_world_group().local_rank == 0) and (
not os.path.exists(path)
):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
print("generating GPU P2P access cache in %s", path)
cache: Dict[str, bool] = {}
ids = list(range(num_dev))
# batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids)))
# NOTE: we use `subprocess` rather than `multiprocessing` here
# because the caller might not have `if __name__ == "__main__":`,
# in that case we cannot use spawn method in multiprocessing.
# However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file.
# use a temporary file to store the result
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
returned = subprocess.run(
[sys.executable, __file__], input=input_bytes, capture_output=True
)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}"
) from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
get_world_group().barrier()
print("reading GPU P2P access cache from %s", path)
with open(path) as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.
If you only need to use the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
import contextlib
import gc
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
import server.envs as envs
from server.inference.platforms import current_platform
from server.utils import direct_register_custom_op, supports_custom_op
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size()))
)
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
_group_name_counter: Dict[str, int] = {}
def _get_unique_name(name: str) -> str:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if name not in _group_name_counter:
_group_name_counter[name] = 0
newname = f"{name}:{_group_name_counter[name]}"
_group_name_counter[name] += 1
return newname
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
def _register_group(group: "GroupCoordinator") -> None:
_groups[group.unique_name] = weakref.ref(group)
if supports_custom_op():
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce_in_place(tensor)
def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
return
direct_register_custom_op(
op_name="inplace_all_reduce",
op_func=inplace_all_reduce,
mutates_args=["tensor"],
fake_impl=inplace_all_reduce_fake,
)
def outplace_all_reduce(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)
def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor:
return torch.empty_like(tensor)
direct_register_custom_op(
op_name="outplace_all_reduce",
op_func=outplace_all_reduce,
mutates_args=[],
fake_impl=outplace_all_reduce_fake,
)
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
a specific implementation (e.g. switch allreduce implementation
based on the tensor size and cuda graph mode).
"""
# available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
# Process | Node | Rank | Local Rank | Rank in Group
# 0 | 0 | 0 | 0 | 0
# 1 | 0 | 1 | 1 | 1
# 2 | 1 | 2 | 0 | 2
# 3 | 1 | 3 | 1 | 3
local_rank: int # local rank used to assign devices
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
self,
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group
assert self.cpu_group is not None
assert self.device_group is not None
assert current_platform.is_cuda_alike()
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
# lazy import to avoid documentation build error
from server.inference.distributed.custom_all_reduce import CustomAllreduce
from server.inference.distributed.pynccl import PyNcclCommunicator
self.pynccl_comm: Optional[PyNcclCommunicator] = None
# if use_pynccl and self.world_size > 1:
# self.pynccl_comm = PyNcclCommunicator(
# group=self.cpu_group,
# device=self.device,
# )
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
#### we assume we won't use tpu or hpu or xpu or messagequeue broadcast
# from vllm.distributed.device_communicators.tpu_communicator import (
# TpuCommunicator)
# self.tpu_communicator: Optional[TpuCommunicator] = None
# if use_tpu_communicator and self.world_size > 1:
# self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
self.tpu_communicator = None
# from vllm.distributed.device_communicators.hpu_communicator import (
# HpuCommunicator)
# self.hpu_communicator: Optional[HpuCommunicator]
# if use_hpu_communicator and self.world_size > 1:
# self.hpu_communicator = HpuCommunicator(group=self.device_group)
self.hpu_communicator = None
# from vllm.distributed.device_communicators.xpu_communicator import (
# XpuCommunicator)
# self.xpu_communicator: Optional[XpuCommunicator]
# if use_xpu_communicator and self.world_size > 1:
# self.xpu_communicator = XpuCommunicator(group=self.device_group)
self.xpu_communicator = None
# from vllm.distributed.device_communicators.shm_broadcast import (
# MessageQueue)
# self.mq_broadcaster: Optional[MessageQueue] = None
# if use_message_queue_broadcaster and self.world_size > 1:
# self.mq_broadcaster = MessageQueue.create_from_process_group(
# self.cpu_group, 1 << 22, 6)
self.mq_broadcaster = None
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
return self.ranks[0]
@property
def last_rank(self):
"""Return the global rank of the last process in the group"""
return self.ranks[-1]
@property
def is_first_rank(self):
"""Return whether the caller is the first process in the group"""
return self.rank == self.first_rank
@property
def is_last_rank(self):
"""Return whether the caller is the last process in the group"""
return self.rank == self.last_rank
@property
def next_rank(self):
"""Return the global rank of the process that follows the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group + 1) % world_size]
@property
def prev_rank(self):
"""Return the global rank of the process that precedes the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group - 1) % world_size]
@contextmanager
def graph_capture(
self, graph_capture_context: Optional[GraphCaptureContext] = None
):
if graph_capture_context is None:
stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
ca_comm = self.ca_comm
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the
# tensor size is too large, it will fallback to the next
# available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using
# CUDA graph, we use either custom all-reduce kernel or
# PyTorch NCCL. We always prioritize using custom all-reduce
# kernel but fall back to PyTorch or pynccl if it is
# disabled or not supported.
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
)
with maybe_pynccl_context:
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if input_.is_cpu:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
if not supports_custom_op():
self._all_reduce_in_place(input_)
return input_
if self.tpu_communicator is not None and not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)
if self.hpu_communicator is not None and not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)
if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)
if (
self.ca_comm is not None
and not self.ca_comm.disabled
and self.ca_comm.should_custom_ar(input_)
):
return torch.ops.vllm.outplace_all_reduce(
input_, group_name=self.unique_name, bsz_tensor=bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap
)
else:
#assert self.ca_comm is not None
#assert not self.ca_comm.disabled
#assert self.ca_comm.should_custom_ar(input_)
torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name)
return input_
def _all_reduce_out_place(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:
ca_comm = self.ca_comm
assert ca_comm is not None
assert not ca_comm.disabled
out = ca_comm.custom_all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)
assert out is not None
return out
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
# Reshape
output_tensor = output_tensor.reshape((world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
return output_tensor
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim)
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(self, input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(
input_, src=self.ranks[src], group=self.device_group
)
return input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
"""Broadcast the input object.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj
if self.mq_broadcaster is not None:
assert src == 0, "Message queue broadcaster only supports src=0"
return self.mq_broadcaster.broadcast_object(obj)
if self.rank_in_group == src:
torch.distributed.broadcast_object_list(
[obj], src=self.ranks[src], group=self.cpu_group
)
return obj
else:
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=self.ranks[src], group=self.cpu_group
)
return recv[0]
def broadcast_object_list(
self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None
):
"""Broadcast the input object list.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(
obj_list, src=self.ranks[src], group=self.device_group
)
return obj_list
def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank_in_group, (
"Invalid destination rank. Destination rank is the same "
"as the current rank."
)
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor(
[object_tensor.numel()], dtype=torch.long, device="cpu"
)
# Send object size
torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
# Send object
torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
return None
def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert (
src != self.rank_in_group
), "Invalid source rank. Source rank is the same as the current rank."
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(
size_tensor, src=self.ranks[src], group=self.cpu_group
)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu",
)
rank_object = torch.distributed.recv(
object_tensor, src=self.ranks[src], group=self.cpu_group
)
assert (
rank_object == rank_size
), "Received object sender rank does not match the size sender rank."
obj = pickle.loads(object_tensor.numpy().tobytes())
return obj
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})"
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.broadcast_object(metadata_list, src=src)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor, src=self.ranks[src], group=metadata_group, async_op=True
)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
metadata_list = self.broadcast_object(None, src=src)
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(
value.size, dtype=value.dtype, device=value.device
)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True,
)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group
)
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(
tensor, dst=self.ranks[dst], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group
)
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (
all_gather_group is not None
and tensor.numel() % all_gather_size == 0
)
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(
tensor, src=self.ranks[src], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
terrible because it is internally a broadcast operation with
secretly created GPU tensors. It is easy to mess up the current
device. Use the CPU group instead.
"""
torch.distributed.barrier(group=self.cpu_group)
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None
def get_world_group() -> GroupCoordinator:
assert _WORLD is not None, "world group is not initialized"
return _WORLD
def init_world_group(
ranks: List[int], local_rank: int, backend: str
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
group_name="world",
)
def init_model_parallel_group(
group_ranks: List[List[int]],
local_rank: int,
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
_TP: Optional[GroupCoordinator] = None
def get_tp_group() -> GroupCoordinator:
assert _TP is not None, "tensor model parallel group is not initialized"
return _TP
# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group
_PP: Optional[GroupCoordinator] = None
def get_pp_group() -> GroupCoordinator:
assert _PP is not None, "pipeline model parallel group is not initialized"
return _PP
# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
@contextmanager
def graph_capture():
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
context
):
yield context
_ENABLE_CUSTOM_ALL_REDUCE = True
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
):
print(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
world_size,
rank,
local_rank,
distributed_init_method,
backend,
)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment"
)
# this backend is used for WORLD
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1:
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
else:
assert (
_WORLD.world_size == torch.distributed.get_world_size()
), "world group already initialized with a different world size"
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
)
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
global _TP
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(
range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
)
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="tp",
)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False,
group_name="pp",
)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size, pipeline_model_parallel_size, backend
)
return
assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
"tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}"
)
pp_world_size = get_pp_group().world_size
assert pp_world_size == pipeline_model_parallel_size, (
"pipeline parallel group already initialized, but of unexpected size: "
f"{pp_world_size=} vs. "
f"{pipeline_model_parallel_size=}"
)
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return _TP is not None and _PP is not None
_TP_STATE_PATCHED = False
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
_TP_STATE_PATCHED = True
old_tp_group = get_tp_group()
global _TP
_TP = tp_group
try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
if _TP:
_TP.destroy()
_TP = None
global _PP
if _PP:
_PP.destroy()
_PP = None
def destroy_distributed_environment():
global _WORLD
if _WORLD:
_WORLD.destroy()
_WORLD = None
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
if shutdown_ray:
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
if not current_platform.is_cpu():
torch.cuda.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
"""
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
"""
assert (
torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL
), "in_the_same_node_as should be tested with a non-NCCL group."
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)
magic_message = b"magic_message"
shm = None
try:
with contextlib.suppress(OSError):
if rank == source_rank:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[: len(magic_message)] = magic_message
torch.distributed.broadcast_object_list(
[shm.name], src=ranks[source_rank], group=pg
)
is_in_the_same_node[rank] = 1
else:
# try to open the shared memory segment
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=ranks[source_rank], group=pg
)
name = recv[0]
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
shm = shared_memory.SharedMemory(name=name)
if shm.buf[: len(magic_message)] == magic_message:
is_in_the_same_node[rank] = 1
except Exception as e:
print("Error ignored in is_in_the_same_node: %s", e)
finally:
if shm:
shm.close()
torch.distributed.barrier(group=pg)
# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == source_rank and shm:
shm.unlink()
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
return [x == 1 for x in is_in_the_same_node.tolist()]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment