Unverified Commit 8275049c authored by Zhang, Liangang and committed by GitHub

Add device support (#1607)

parent 5476ccad
......@@ -423,6 +423,9 @@ class ScheduleBatch:
# Stream
has_stream: bool = False
# device
device: str = "cuda"
# Has regex
has_regex: bool = False
......@@ -439,6 +442,7 @@ class ScheduleBatch:
tree_cache=tree_cache,
return_logprob=return_logprob,
has_stream=has_stream,
device=req_to_token_pool.device,
has_regex=has_regex,
)
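Note: the batch now records which device its tensors live on, taken from the request-to-token pool instead of assuming CUDA. A minimal sketch of the pattern (the stub classes below are illustrative, not the real SGLang types):

import dataclasses

import torch


@dataclasses.dataclass
class PoolStub:
    # stands in for req_to_token_pool, which knows where its tensors live
    device: str = "cpu"


@dataclasses.dataclass
class BatchStub:
    device: str = "cuda"  # default, overridden from the pool at construction

    @classmethod
    def init_new(cls, pool: PoolStub) -> "BatchStub":
        return cls(device=pool.device)


batch = BatchStub.init_new(PoolStub(device="cpu"))
# later allocations follow the batch's device instead of a hardcoded "cuda"
mask = torch.zeros(4, dtype=torch.bool, device=batch.device)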
......
......@@ -81,10 +81,11 @@ class ModelRunner:
# Parse args
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.device = server_args.device
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = tp_size
-self.nccl_port = nccl_port
+self.dist_port = nccl_port
self.server_args = server_args
self.is_multimodal_model = is_multimodal_model(
self.model_config.hf_config.architectures
......@@ -132,39 +133,45 @@ class ModelRunner:
server_args.max_running_requests,
server_args.max_total_tokens,
)
-self.init_cublas()
-self.init_attention_backend()
-self.init_cuda_graphs()
+if self.device == "cuda":
+    self.init_cublas()
+    self.init_attention_backend()
+    self.init_cuda_graphs()
+else:
+    self.init_attention_backend()
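Note: CUDA-only setup (the cuBLAS warmup and CUDA graph capture) is now skipped on other devices, while the attention backend is always initialized. A rough sketch of the dispatch, with stub methods standing in for the real initializers:

class RunnerSketch:
    """Illustrative only; mirrors the device branch added to ModelRunner.__init__."""

    def __init__(self, device: str):
        self.device = device
        if self.device == "cuda":
            self.init_cublas()             # CUDA-specific GEMM warmup
            self.init_attention_backend()
            self.init_cuda_graphs()        # CUDA graphs only make sense on CUDA
        else:
            self.init_attention_backend()  # shared path for non-CUDA devices

    def init_cublas(self):
        pass

    def init_attention_backend(self):
        pass

    def init_cuda_graphs(self):
        pass


RunnerSketch("xpu")  # only the attention backend gets set up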
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
# Init torch distributed
-torch.cuda.set_device(self.gpu_id)
-logger.info("Init nccl begin.")
+if self.device == "cuda":
+    torch.cuda.set_device(self.gpu_id)
+    backend = "nccl"
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if self.server_args.dist_init_addr:
nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
backend=backend,
world_size=self.tp_size,
rank=self.tp_rank,
local_rank=self.gpu_id,
-distributed_init_method=nccl_init_method,
+distributed_init_method=dist_init_method,
)
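Note: the renames (nccl_port to dist_port, nccl_init_method to dist_init_method) plus the backend variable decouple process-group setup from NCCL. Only the CUDA branch is shown setting backend = "nccl"; a non-CUDA deployment would have to pick another torch.distributed backend (for example "gloo" or Intel's "ccl", an assumption not shown in this hunk). A standalone sketch of the same call shape, using plain torch.distributed rather than vLLM's wrapper:

import torch.distributed as dist


def pick_backend(device: str) -> str:
    # "nccl" for CUDA as in the patch; "gloo" here is only an assumption for other devices
    return "nccl" if device == "cuda" else "gloo"


dist.init_process_group(
    backend=pick_backend("cpu"),
    init_method="tcp://127.0.0.1:29512",  # plays the role of dist_init_method
    world_size=1,
    rank=0,
)
dist.destroy_process_group()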
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
min_per_gpu_memory = get_available_gpu_memory(
-self.gpu_id, distributed=self.tp_size > 1
+self.device, self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
-if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+if self.device == "cuda" and not all(
+    in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
......@@ -172,7 +179,7 @@ class ModelRunner:
# Check memory for tensor parallelism
if self.tp_size > 1:
-local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
......@@ -182,23 +189,22 @@ class ModelRunner:
def load_model(self):
logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
# This can reduce thread conflicts and speed up weight loading.
torch.set_num_threads(1)
-if torch.cuda.get_device_capability()[0] < 8:
-    logger.info(
-        "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-    )
-    self.server_args.dtype = "float16"
-    if torch.cuda.get_device_capability()[1] < 5:
-        raise RuntimeError("SGLang only supports sm75 and above.")
+if self.device == "cuda":
+    if torch.cuda.get_device_capability()[0] < 8:
+        logger.info(
+            "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+        )
+        self.server_args.dtype = "float16"
+        if torch.cuda.get_device_capability()[1] < 5:
+            raise RuntimeError("SGLang only supports sm75 and above.")
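Note: the compute-capability checks are now guarded because torch.cuda.get_device_capability() is only meaningful when running on CUDA. A hedged sketch of the same dtype-downgrade logic as a free function:

import torch


def resolve_dtype(device: str, requested: str) -> str:
    """Downgrade to float16 on pre-sm80 CUDA GPUs; leave other devices alone."""
    if device == "cuda" and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if major < 8:
            if minor < 5:  # below sm75 once major is already below 8
                raise RuntimeError("sm75 or newer is required")
            return "float16"  # bfloat16 needs sm80 or newer
    return requested


print(resolve_dtype("cpu", "bfloat16"))  # -> bfloat16, the check is skipped off-CUDA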
# Prepare the vllm model config
monkey_patch_vllm_dummy_weight_loader()
-self.device_config = DeviceConfig()
self.load_config = LoadConfig(load_format=self.server_args.load_format)
self.vllm_model_config = VllmModelConfig(
model=self.server_args.model_path,
......@@ -220,7 +226,7 @@ class ModelRunner:
self.model = get_model(
model_config=self.vllm_model_config,
load_config=self.load_config,
-device_config=self.device_config,
+device_config=DeviceConfig(self.device),
parallel_config=None,
scheduler_config=None,
lora_config=None,
......@@ -240,7 +246,7 @@ class ModelRunner:
f"Load weight end. "
f"type={type(self.model).__name__}, "
f"dtype={self.dtype}, "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def update_weights(self, model_path: str, load_format: str):
......@@ -254,10 +260,10 @@ class ModelRunner:
logger.info(
f"Update weights begin. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
-target_device = torch.device(self.device_config.device)
+target_device = torch.device(self.device)
try:
# TODO: Use a better method to check this
......@@ -343,7 +349,7 @@ class ModelRunner:
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
-self.gpu_id, distributed=self.tp_size > 1
+self.device, self.gpu_id, distributed=self.tp_size > 1
)
if (
self.model_config.attention_arch == AttentionArch.MLA
......@@ -409,11 +415,10 @@ class ModelRunner:
4096,
)
device = "cuda"
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
-device=device,
+device=self.device,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
......@@ -425,7 +430,7 @@ class ModelRunner:
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
-device=device,
+device=self.device,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......@@ -434,11 +439,11 @@ class ModelRunner:
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
-device=device,
+device=self.device,
)
logger.info(
f"Memory pool end. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
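Note: the hardcoded device = "cuda" is gone; the pools allocate their backing tensors on self.device. An illustrative miniature pool (not the real ReqToTokenPool) showing the idea:

import torch


class TinyReqToTokenPool:
    """Toy stand-in for ReqToTokenPool: one mapping tensor placed on `device`."""

    def __init__(self, size: int, max_context_len: int, device: str):
        self.device = device
        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
        )


pool = TinyReqToTokenPool(size=8, max_context_len=16, device="cpu")
print(pool.req_to_token.device)  # cpu here; cuda:0 or xpu:0 under those runners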
def init_cublas(self):
......
......@@ -37,6 +37,9 @@ class SamplingBatchInfo:
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
# Device
device: str = "cuda"
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
......@@ -62,6 +65,7 @@ class SamplingBatchInfo:
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
device=batch.input_ids.device,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
......@@ -75,7 +79,7 @@ class SamplingBatchInfo:
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device="cuda",
device=batch.input_ids.device,
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
......@@ -107,7 +111,7 @@ class SamplingBatchInfo:
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
device=self.device,
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
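Note: SamplingBatchInfo now remembers the batch device (taken from batch.input_ids.device) and uses it for every scratch tensor it creates, including the penalizer orchestrator, the linear-penalty buffer above, and the regex vocab mask below. The pattern in isolation:

import torch


def make_sampling_buffers(bs: int, vocab_size: int, device: torch.device):
    # both buffers follow the batch's device instead of a hardcoded "cuda"
    linear_penalties = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
    vocab_mask = torch.zeros((bs, vocab_size), dtype=torch.bool, device=device)
    return linear_penalties, vocab_mask


input_ids = torch.tensor([[1, 2, 3]])  # the device is inferred from the batch tensors
penalties, mask = make_sampling_buffers(1, 32, input_ids.device)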
......@@ -119,7 +123,10 @@ class SamplingBatchInfo:
if has_regex:
self.vocab_mask = torch.zeros(
-len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
+len(self.temperatures),
+self.vocab_size,
+dtype=torch.bool,
+device=self.device,
)
for i, regex_fsm in enumerate(self.regex_fsms):
if regex_fsm is not None:
......@@ -144,7 +151,12 @@ class SamplingBatchInfo:
@staticmethod
def merge_bias_tensor(
-lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+lhs: torch.Tensor,
+rhs: torch.Tensor,
+bs1: int,
+bs2: int,
+device: str,
+default: int = 0,
):
# bias tensor can be None
if lhs is not None or rhs is not None:
......@@ -155,9 +167,9 @@ class SamplingBatchInfo:
shape, dtype = rhs.shape[1:], rhs.dtype
with torch.dtype(dtype):
if lhs is None:
lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
if rhs is None:
rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
return torch.cat([lhs, rhs])
return None
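Note: merge_bias_tensor fills in a missing side with a default-valued tensor, and that tensor must now be created on the caller-supplied device so it can be concatenated with the existing one. A self-contained sketch of the same behaviour (it passes dtype explicitly instead of using a dtype context):

import torch


def merge_bias_sketch(lhs, rhs, bs1, bs2, device, default=0.0):
    """Illustrative merge: whichever side is None is filled on `device`."""
    if lhs is None and rhs is None:
        return None
    shape, dtype = (lhs.shape[1:], lhs.dtype) if lhs is not None else (rhs.shape[1:], rhs.dtype)
    if lhs is None:
        lhs = torch.full((bs1, *shape), default, dtype=dtype, device=device)
    if rhs is None:
        rhs = torch.full((bs2, *shape), default, dtype=dtype, device=device)
    return torch.cat([lhs, rhs])


out = merge_bias_sketch(torch.ones(2, 4), None, bs1=2, bs2=3, device="cpu")
print(out.shape)  # torch.Size([5, 4])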
......@@ -176,5 +188,5 @@ class SamplingBatchInfo:
setattr(self, item, torch.concat([self_val, other_val]))
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-self.logit_bias, other.logit_bias, len(self), len(other)
+self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
......@@ -36,6 +36,7 @@ class ServerArgs:
skip_tokenizer_init: bool = False
load_format: str = "auto"
dtype: str = "auto"
device: str = "cuda"
kv_cache_dtype: str = "auto"
trust_remote_code: bool = True
context_length: Optional[int] = None
......@@ -237,6 +238,13 @@ class ServerArgs:
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.',
)
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help="The device type.",
)
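Note: the new flag is wired into ServerArgs, but its choices list still only allows "cuda"; the XPU branch added in utils suggests more device types are meant to follow. A minimal argparse reproduction of the flag (standalone, not the real server parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    choices=["cuda"],  # only "cuda" is accepted at this point in the series
    help="The device type.",
)
print(parser.parse_args([]).device)                    # -> cuda
print(parser.parse_args(["--device", "cuda"]).device)  # -> cuda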
parser.add_argument(
"--kv-cache-dtype",
type=str,
......
......@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
return wrapper
-def get_available_gpu_memory(gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False):
"""
Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs.
"""
-num_gpus = torch.cuda.device_count()
-assert gpu_id < num_gpus
+if device == "cuda":
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus
-if torch.cuda.current_device() != gpu_id:
-    print(
-        f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-        "which may cause useless memory allocation for torch CUDA context.",
-    )
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
-torch.cuda.empty_cache()
-free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+    torch.cuda.empty_cache()
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
elif device == "xpu":
num_gpus = torch.xpu.device_count()
assert gpu_id < num_gpus
if torch.xpu.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
"which may cause useless memory allocation for torch XPU context.",
)
torch.xpu.empty_cache()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
free_gpu_memory = total_gpu_memory - used_memory
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device("cuda", gpu_id)
torch.device(device, gpu_id)
)
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
free_gpu_memory = tensor.item()
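Note: callers now pass the device string first (see the updated call sites in model_runner.py above). A hedged, self-contained version of the dispatch, returning gigabytes as the "avail mem=... GB" log lines expect; the torch.xpu calls assume a PyTorch build with Intel XPU support:

import torch


def free_device_memory_gb(device: str, index: int = 0) -> float:
    """Rough re-statement of the dispatch in get_available_gpu_memory."""
    if device == "cuda" and torch.cuda.is_available():
        free_bytes, _total = torch.cuda.mem_get_info(index)
    elif device == "xpu" and hasattr(torch, "xpu") and torch.xpu.is_available():
        props = torch.xpu.get_device_properties(index)
        free_bytes = props.total_memory - torch.xpu.memory_allocated()
    else:
        raise ValueError(f"device {device!r} is unsupported or unavailable here")
    return free_bytes / (1 << 30)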
......