Unverified commit 8275049c authored by Zhang, Liangang, committed by GitHub

Add device support (#1607)

parent 5476ccad
@@ -423,6 +423,9 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

+    # device
+    device: str = "cuda"
+
     # Has regex
     has_regex: bool = False

@@ -439,6 +442,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            device=req_to_token_pool.device,
             has_regex=has_regex,
         )
......
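The new `device` field lets the batch follow the placement of its token pool instead of assuming CUDA, so downstream tensor allocations can use `batch.device` directly. A minimal, self-contained sketch of that pattern; `TokenPool` and `BatchSketch` are illustrative stand-ins, not the real SGLang classes:

```python
from dataclasses import dataclass

import torch


@dataclass
class TokenPool:
    # Simplified stand-in: the real pool allocates its buffers on `device`.
    device: str = "cuda"


@dataclass
class BatchSketch:
    # Mirrors the new ScheduleBatch field: default "cuda", overridden from the pool.
    device: str = "cuda"

    def new_ids(self, n: int) -> torch.Tensor:
        # Allocate on the batch's device instead of a hard-coded "cuda".
        return torch.zeros(n, dtype=torch.int64, device=self.device)


pool = TokenPool(device="cpu")            # e.g. "cuda", "xpu", or "cpu"
batch = BatchSketch(device=pool.device)
print(batch.new_ids(4).device)            # -> cpu
```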
@@ -81,10 +81,11 @@ class ModelRunner:
         # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
+        self.device = server_args.device
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.nccl_port = nccl_port
+        self.dist_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
@@ -132,39 +133,45 @@ class ModelRunner:
             server_args.max_running_requests,
             server_args.max_total_tokens,
         )
-        self.init_cublas()
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        if self.device == "cuda":
+            self.init_cublas()
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+        else:
+            self.init_attention_backend()

     def init_torch_distributed(self):
+        logger.info("Init torch distributed begin.")
         # Init torch distributed
-        torch.cuda.set_device(self.gpu_id)
-        logger.info("Init nccl begin.")
+        if self.device == "cuda":
+            torch.cuda.set_device(self.gpu_id)
+            backend = "nccl"

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
         if self.server_args.dist_init_addr:
-            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
+            dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
-            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+            dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
-            backend="nccl",
+            backend=backend,
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method,
+            distributed_init_method=dist_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         min_per_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()

         # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
         # so we disable padding in cuda graph.
-        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+        if self.device == "cuda" and not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        ):
             self.server_args.disable_cuda_graph_padding = True
             logger.info(
                 "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
@@ -172,7 +179,7 @@ class ModelRunner:
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
@@ -182,23 +189,22 @@ class ModelRunner:
     def load_model(self):
         logger.info(
-            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         # This can reduce thread conflicts and speed up weight loading.
         torch.set_num_threads(1)
-        if torch.cuda.get_device_capability()[0] < 8:
-            logger.info(
-                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
-            )
-            self.server_args.dtype = "float16"
-            if torch.cuda.get_device_capability()[1] < 5:
-                raise RuntimeError("SGLang only supports sm75 and above.")
+        if self.device == "cuda":
+            if torch.cuda.get_device_capability()[0] < 8:
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
+                self.server_args.dtype = "float16"
+                if torch.cuda.get_device_capability()[1] < 5:
+                    raise RuntimeError("SGLang only supports sm75 and above.")

         # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
-        self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
         self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
@@ -220,7 +226,7 @@ class ModelRunner:
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
-            device_config=self.device_config,
+            device_config=DeviceConfig(self.device),
             parallel_config=None,
             scheduler_config=None,
             lora_config=None,
@@ -240,7 +246,7 @@ class ModelRunner:
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

     def update_weights(self, model_path: str, load_format: str):
@@ -254,10 +260,10 @@ class ModelRunner:
         logger.info(
             f"Update weights begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
-        target_device = torch.device(self.device_config.device)
+        target_device = torch.device(self.device)

         try:
             # TODO: Use a better method to check this
@@ -343,7 +349,7 @@ class ModelRunner:
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
-            self.gpu_id, distributed=self.tp_size > 1
+            self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -409,11 +415,10 @@ class ModelRunner:
             4096,
         )
-        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
-            device=device,
+            device=self.device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -425,7 +430,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -434,11 +439,11 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
-                device=device,
+                device=self.device,
             )
         logger.info(
             f"Memory pool end. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

     def init_cublas(self):
......
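The gist of the `ModelRunner` change is that CUDA-only setup (the cuBLAS warm-up and CUDA graph capture) is now gated on `self.device`, while the attention backend is initialized on every device. A toy sketch of that gating; the class name and print statements below are placeholders, not the real initializers:

```python
class RunnerSketch:
    """Illustrative only: mirrors how ModelRunner gates CUDA-specific setup."""

    def __init__(self, device: str):
        self.device = device  # taken from server_args.device in the real code

    def init_cublas(self):             # CUDA-only warm-up
        print("cuBLAS warmed up")

    def init_cuda_graphs(self):        # CUDA-only capture
        print("CUDA graphs captured")

    def init_attention_backend(self):  # needed on every device
        print("attention backend ready")

    def initialize(self):
        if self.device == "cuda":
            self.init_cublas()
            self.init_attention_backend()
            self.init_cuda_graphs()
        else:
            # Non-CUDA devices skip the CUDA-specific steps.
            self.init_attention_backend()


RunnerSketch("xpu").initialize()   # -> only the attention backend is set up
```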
@@ -37,6 +37,9 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None

+    # Device
+    device: str = "cuda"
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -62,6 +65,7 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
+            device=batch.input_ids.device,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -75,7 +79,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device="cuda",
+            device=batch.input_ids.device,
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
@@ -107,7 +111,7 @@ class SamplingBatchInfo:
                     self.linear_penalties = torch.zeros(
                         (bs, self.vocab_size),
                         dtype=torch.float32,
-                        device="cuda",
+                        device=self.device,
                     )
                 self.linear_penalties = penalizer.apply(self.linear_penalties)
@@ -119,7 +123,10 @@ class SamplingBatchInfo:
         if has_regex:
             self.vocab_mask = torch.zeros(
-                len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.temperatures),
+                self.vocab_size,
+                dtype=torch.bool,
+                device=self.device,
             )
             for i, regex_fsm in enumerate(self.regex_fsms):
                 if regex_fsm is not None:
@@ -144,7 +151,12 @@ class SamplingBatchInfo:
     @staticmethod
     def merge_bias_tensor(
-        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+        lhs: torch.Tensor,
+        rhs: torch.Tensor,
+        bs1: int,
+        bs2: int,
+        device: str,
+        default: int = 0,
     ):
         # bias tensor can be None
         if lhs is not None or rhs is not None:
@@ -155,9 +167,9 @@ class SamplingBatchInfo:
                 shape, dtype = rhs.shape[1:], rhs.dtype
             with torch.dtype(dtype):
                 if lhs is None:
-                    lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+                    lhs = torch.empty((bs1, *shape), device=device).fill_(default)
                 if rhs is None:
-                    rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+                    rhs = torch.empty((bs2, *shape), device=device).fill_(default)
                 return torch.cat([lhs, rhs])
         return None
@@ -176,5 +188,5 @@ class SamplingBatchInfo:
             setattr(self, item, torch.concat([self_val, other_val]))

         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other)
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
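A recurring change in this file is passing the batch's device into tensor allocations instead of hard-coding `"cuda"`, including the extra `device` parameter on `merge_bias_tensor`. The following standalone sketch mirrors that helper's logic under stated assumptions (float tensors, a filled default); it is a simplified approximation, not the exact SGLang implementation:

```python
from typing import Optional

import torch


def merge_bias_tensor(
    lhs: Optional[torch.Tensor],
    rhs: Optional[torch.Tensor],
    bs1: int,
    bs2: int,
    device: str,
    default: float = 0.0,
) -> Optional[torch.Tensor]:
    # Both halves absent: nothing to merge.
    if lhs is None and rhs is None:
        return None
    # Take shape/dtype from whichever side exists and fill the missing side
    # on the caller-supplied device rather than a hard-coded "cuda".
    ref = lhs if lhs is not None else rhs
    shape, dtype = ref.shape[1:], ref.dtype
    if lhs is None:
        lhs = torch.full((bs1, *shape), default, dtype=dtype, device=device)
    if rhs is None:
        rhs = torch.full((bs2, *shape), default, dtype=dtype, device=device)
    return torch.cat([lhs, rhs])


# One batch carries a logit bias, the other does not; the result lands on `device`.
bias = torch.randn(2, 8)
merged = merge_bias_tensor(bias, None, bs1=2, bs2=3, device="cpu")
print(merged.shape)  # torch.Size([5, 8])
```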
@@ -36,6 +36,7 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    device: str = "cuda"
     kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
@@ -237,6 +238,13 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
......
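The `--device` flag defaults to `"cuda"` and is constrained by argparse `choices`, so an unsupported value fails at parse time before any device-specific code runs. A stripped-down sketch of the same wiring; the `Args` dataclass below is illustrative, not the real `ServerArgs`:

```python
import argparse
from dataclasses import dataclass


@dataclass
class Args:
    device: str = "cuda"


parser = argparse.ArgumentParser()
parser.add_argument(
    "--device",
    type=str,
    default=Args.device,
    choices=["cuda"],  # the diff only lists "cuda"; more backends would extend this list
    help="The device type.",
)

# e.g. `python launch_server.py --device cuda`
args = Args(device=parser.parse_args(["--device", "cuda"]).device)
print(args.device)  # -> cuda
```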
@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
-    num_gpus = torch.cuda.device_count()
-    assert gpu_id < num_gpus
+    if device == "cuda":
+        num_gpus = torch.cuda.device_count()
+        assert gpu_id < num_gpus

-    if torch.cuda.current_device() != gpu_id:
-        print(
-            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-            "which may cause useless memory allocation for torch CUDA context.",
-        )
+        if torch.cuda.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+                "which may cause useless memory allocation for torch CUDA context.",
+            )

-    torch.cuda.empty_cache()
-    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+        torch.cuda.empty_cache()
+        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+
+    elif device == "xpu":
+        num_gpus = torch.xpu.device_count()
+        assert gpu_id < num_gpus
+        if torch.xpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
+                "which may cause useless memory allocation for torch XPU context.",
+            )
+        torch.xpu.empty_cache()
+        used_memory = torch.xpu.memory_allocated()
+        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
+        free_gpu_memory = total_gpu_memory - used_memory

     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device("cuda", gpu_id)
+            torch.device(device, gpu_id)
         )
         torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()
......
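When `distributed=True`, each rank measures its own free memory and an all-reduce with `ReduceOp.MIN` makes every rank adopt the most constrained value, so all tensor-parallel workers agree on the same memory budget. A small single-process sketch of that reduction using the CPU `gloo` backend; the address, port, and helper name are arbitrary choices for the demo, not part of SGLang:

```python
import torch
import torch.distributed as dist


def min_free_memory_across_ranks(free_gb: float) -> float:
    """Sketch of the distributed branch: every rank contributes its free memory
    and the MIN becomes the budget all ranks use."""
    t = torch.tensor(free_gb, dtype=torch.float32)
    dist.all_reduce(t, op=dist.ReduceOp.MIN)
    return t.item()


if __name__ == "__main__":
    # Single-rank demo on CPU; in SGLang the tensor lives on
    # torch.device(device, gpu_id) and the group spans all TP ranks.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
    )
    print(min_free_memory_across_ranks(42.0))  # -> 42.0 with a single rank
    dist.destroy_process_group()
```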