Commit f44e9f9e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev' into v0.7.2-fusion

parents 525d9d7e 8fc15e04
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
import numa,os
# Bind the current process to a NUMA node.
def bind_to_numa(local_rank):
    """Bind this process to the NUMA node configured for ``local_rank``.

    The node index is read from the environment variable
    ``VLLM_RANK<local_rank>_NUMA``.

    Args:
        local_rank: Rank used to build the environment variable name.

    Raises:
        ValueError: If the configured node index is larger than the highest
            node index reported by ``numa.get_max_node()``.
    """
    env_str = f"VLLM_RANK{local_rank}_NUMA"
    raw_value = os.getenv(env_str)
    try:
        numa_node = -1 if raw_value is None else int(raw_value)
    except ValueError:
        # A non-integer value counts as "set incorrectly": skip binding
        # instead of crashing worker start-up.
        numa_node = -1
    # No binding when the variable is missing or invalid.
    # TODO: choose the node automatically from the hardware topology.
    if numa_node < 0:
        logger.warning("%s is unset or set incorrectly, vllm will not bind to numa! %s = %d", env_str, env_str, numa_node)
        return
    if numa_node > numa.get_max_node():
        raise ValueError(f"NUMA node {numa_node} is not available.")
    numa.bind([numa_node])
logger = init_logger(__name__)
def create_worker(**kwargs):
    """Build a worker through ``WorkerWrapperBase`` and return it.

    Unless the environment variable ``VLLM_NUMA_BIND`` is set to a value
    that is not greater than zero, the process is first bound to a NUMA
    node via ``bind_to_numa`` and the resulting CPU affinity and memory
    binding are logged. In that case ``kwargs`` must contain
    ``local_rank``.

    Args:
        **kwargs: Forwarded unchanged to ``WorkerWrapperBase.init_worker``;
            ``vllm_config`` is also passed to the wrapper constructor.

    Returns:
        The ``worker`` attribute of the initialised wrapper.
    """
    numa_bind_enabled = int(os.getenv("VLLM_NUMA_BIND", 1)) > 0
    if numa_bind_enabled:
        local_rank = kwargs['local_rank']
        # Bind the current process to the configured NUMA node.
        bind_to_numa(local_rank)
        pid = os.getpid()
        logger.info("########## %d process(rank%s) is running on CPU(s): %s", pid, str(local_rank), str(os.sched_getaffinity(pid)))
        logger.info("########## %d process(rank%s) is running on memnode(s): %s", pid, str(local_rank), str(numa.get_membind()))

    worker_wrapper = WorkerWrapperBase(vllm_config=kwargs.get("vllm_config"))
    worker_wrapper.init_worker(**kwargs)
    return worker_wrapper.worker
class GPUExecutor(ExecutorBase):
    """Executor that forwards every operation to one in-process worker."""

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Create the driver worker, initialise its device, load the model."""
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")
        self.driver_worker = self._create_worker()
        worker = self.driver_worker
        worker.init_device()
        worker.load_model()

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """Assemble the keyword arguments used to initialise a worker.

        When no ``distributed_init_method`` is supplied, one is built from
        the local IP address and a free port.
        """
        init_method = distributed_init_method
        if init_method is None:
            init_method = get_distributed_init_method(get_ip(),
                                                      get_open_port())

        parallel_config = self.parallel_config
        is_driver = (not parallel_config) or (
            rank % parallel_config.tensor_parallel_size == 0)
        return {
            "vllm_config": self.vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": init_method,
            "is_driver_worker": is_driver,
        }

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        """Instantiate a worker for the given ranks via ``create_worker``."""
        worker_kwargs = self._get_worker_kwargs(
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method)
        return create_worker(**worker_kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Ask the driver worker how many KV blocks are available."""
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
        """Log the cache sizing and initialise the KV cache on the worker."""
        # NOTE: Logged here rather than in the worker because other
        # executors may own more than one worker.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        max_model_len = self.model_config.max_model_len
        max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
                           max_model_len)
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    max_model_len, max_concurrency)

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Run one model step on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA registration request to the driver worker."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal request to the driver worker."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the driver worker."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the driver worker."""
        return self.driver_worker.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward a prompt-adapter registration to the driver worker."""
        assert prompt_adapter_request.prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward a prompt-adapter removal to the driver worker."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward a prompt-adapter pin request to the driver worker."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        return self.driver_worker.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Return the prompt-adapter ids reported by the driver worker."""
        return self.driver_worker.list_prompt_adapters()

    def check_health(self) -> None:
        """No-op: this executor performs no health probing."""
        # GPUExecutor is considered healthy for as long as it is running.
        return None

    def start_profile(self) -> None:
        """Start profiling on the driver worker."""
        self.driver_worker.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling on the driver worker."""
        self.driver_worker.stop_profile()
class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """Async flavour of ``GPUExecutor``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        """Run the driver worker's ``execute_model`` through ``make_async``
        and await its result."""
        run_step = make_async(self.driver_worker.execute_model)
        return await run_step(execute_model_req=execute_model_req)
...@@ -42,7 +42,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren ...@@ -42,7 +42,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15 {"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32 {"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 4}, #256
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 4},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 8},#8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "kpack": 1,"num_stages": 0,"num_warps": 8}
] ]
stage2_best_config=[ stage2_best_config=[
...@@ -62,7 +65,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren ...@@ -62,7 +65,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13 {"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15 {"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #16 {"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4} ,#256
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4},# 8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}
] ]
else: else:
stage1_best_config=[ stage1_best_config=[
...@@ -83,7 +90,10 @@ else: ...@@ -83,7 +90,10 @@ else:
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#256
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},
] ]
stage2_best_config=[ stage2_best_config=[
...@@ -103,7 +113,11 @@ else: ...@@ -103,7 +113,11 @@ else:
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #16 {"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #32
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}, #256
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0, "num_warps": 4}, #1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0, "num_warps": 4}, #8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}
] ]
@triton.jit @triton.jit
...@@ -1662,8 +1676,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1662,8 +1676,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# so the cache size and config are already set correctly and # so the cache size and config are already set correctly and
# do not need to be adjusted. # do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
if not use_int8_w8a8:
config = get_config_func(tokens_in_chunk) config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
...@@ -1677,24 +1692,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1677,24 +1692,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config =stage1_best_config[15] config =stage1_best_config[15]
elif m<=64: elif m<=64:
config =stage1_best_config[16] config =stage1_best_config[16]
elif m<256: elif m<=256:
config ={ config =stage1_best_config[17]
"BLOCK_SIZE_M": 16, elif m<=1024:
"BLOCK_SIZE_N": 32, config =stage1_best_config[18]
"BLOCK_SIZE_K": 64, elif m<=8192:
"GROUP_SIZE_M": 1, config =stage1_best_config[19]
"num_stages": 0,
"num_warps": 4
}
else: else:
config ={ config =stage1_best_config[20]
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
if moe_ep_size == 1: if moe_ep_size == 1:
if use_int4_w4a16: if use_int4_w4a16:
...@@ -1740,24 +1745,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1740,24 +1745,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config =stage2_best_config[15] config =stage2_best_config[15]
elif m<=64: elif m<=64:
config =stage2_best_config[16] config =stage2_best_config[16]
elif m<256: elif m<=256:
config ={ config =stage2_best_config[17]
"BLOCK_SIZE_M": 16, elif m<=1024:
"BLOCK_SIZE_N": 32, config =stage2_best_config[18]
"BLOCK_SIZE_K": 64, elif m<=8192:
"GROUP_SIZE_M": 1, config =stage2_best_config[19]
"num_stages": 0,
"num_warps": 4
}
else: else:
config ={ config =stage2_best_config[20]
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
invoke_fused_moe_kernel(intermediate_cache2, invoke_fused_moe_kernel(intermediate_cache2,
w2, w2,
......
...@@ -68,7 +68,6 @@ def per_token_quant_int8(x): ...@@ -68,7 +68,6 @@ def per_token_quant_int8(x):
return x_q, scales return x_q, scales
@triton.jit @triton.jit
def _per_token_group_quant_int8( def _per_token_group_quant_int8(
# Pointers to inputs and output # Pointers to inputs and output
...@@ -76,9 +75,12 @@ def _per_token_group_quant_int8( ...@@ -76,9 +75,12 @@ def _per_token_group_quant_int8(
y_q_ptr, y_q_ptr,
y_s_ptr, y_s_ptr,
# Stride of input # Stride of input
y_stride, group_size,
# Collums of input # M,
N, # K,
# # Collums of input
# N,
SIZE,
# Avoid to divide zero # Avoid to divide zero
eps, eps,
# Information for int8 # Information for int8
...@@ -86,6 +88,7 @@ def _per_token_group_quant_int8( ...@@ -86,6 +88,7 @@ def _per_token_group_quant_int8(
int8_max, int8_max,
# Meta-parameters # Meta-parameters
BLOCK: tl.constexpr, BLOCK: tl.constexpr,
s_num : tl.constexpr,
): ):
"""A Triton-accelerated function to perform """A Triton-accelerated function to perform
per-token-group quantization on a tensor. per-token-group quantization on a tensor.
...@@ -93,21 +96,26 @@ def _per_token_group_quant_int8( ...@@ -93,21 +96,26 @@ def _per_token_group_quant_int8(
""" """
# Map the program id to the row of X and Y it should compute. # Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0) g_id = tl.program_id(0)
y_ptr += g_id * y_stride y_ptr += g_id * BLOCK
y_q_ptr += g_id * y_stride y_q_ptr += g_id * BLOCK
y_s_ptr += g_id y_s_ptr += g_id * s_num
cols = tl.arange(0, BLOCK) # N <= BLOCK cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N s_cols = tl.arange(0, s_num)
mask = g_id * BLOCK + cols < SIZE
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
y = tl.reshape(y, (s_num, 128))
# Quant # Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps) _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
y_s = _absmax / int8_max y_s = (_absmax / int8_max).reshape(s_num, 1)
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
y_q = tl.reshape(y_q, (s_num*128))
y_s = tl.reshape(y_s, (s_num))
tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s) tl.store(y_s_ptr + s_cols, y_s.to(y_s_ptr.dtype.element_ty))
def per_token_group_quant_int8( def per_token_group_quant_int8(
...@@ -139,30 +147,38 @@ def per_token_group_quant_int8( ...@@ -139,30 +147,38 @@ def per_token_group_quant_int8(
int8_min = iinfo.min int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype) x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size N = group_size
m = x.shape[0]
if m<=16:
config={"BLOCK":128,"s_num":1,"num_warps":1,"num_stages":1}
elif m<=256:
config={"BLOCK":1024,"s_num":8,"num_warps":4,"num_stages":1}
else:
config={"BLOCK":2048,"s_num":16,"num_warps":4,"num_stages":2}
grid = lambda META: (
triton.cdiv(x.numel(), META['BLOCK']),
)
x_s = torch.empty( x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,), x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device, device=x.device,
dtype=torch.float32, dtype=torch.float32,
) )
BLOCK = triton.next_power_of_2(N) _per_token_group_quant_int8[grid](
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M,)](
x, x,
x_q, x_q,
x_s, x_s,
group_size, group_size,
N, # M,
# K,
# N,
x.numel(),
eps, eps,
int8_min=int8_min, int8_min=int8_min,
int8_max=int8_max, int8_max=int8_max,
BLOCK=BLOCK, **config
num_warps=num_warps,
num_stages=num_stages,
) )
return x_q, x_s return x_q, x_s
...@@ -458,59 +474,6 @@ def w8a8_block_int8_matmul( ...@@ -458,59 +474,6 @@ def w8a8_block_int8_matmul(
return C return C
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Block-wise scaled matmul implemented with plain torch ops.

    Computes ``A @ B.T`` tile by tile in float32. Each ``(k, n)`` tile
    product is multiplied by the matching column of ``As`` and the scalar
    ``Bs[n][k]`` before being accumulated into the result.

    Args:
        A: Left operand; leading dims are flattened, last dim is K.
        B: Contiguous 2-D right operand of shape ``(N, K)``.
        As: Scales for ``A``; same leading dims as ``A``, last dim equal to
            ``ceil(K / block_k)``.
        Bs: 2-D scales for ``B`` of shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element sequence ``(block_n, block_k)``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n = block_size[0]
    block_k = block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    rows = A.numel() // A.shape[-1]
    N, K = B.shape
    result_shape = A.shape[:-1] + (N,)
    A = A.reshape(rows, A.shape[-1])
    As = As.reshape(rows, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    acc = torch.zeros((rows, N), dtype=torch.float32, device=A.device)

    # Accumulate K-tiles in ascending order for every N-tile.
    for k_idx in range(k_tiles):
        k_start = k_idx * block_k
        k_stop = min(k_start + block_k, K)
        a_tile = A[:, k_start:k_stop]
        a_scale = As[:, k_idx:k_idx + 1]
        for n_idx in range(n_tiles):
            n_start = n_idx * block_n
            n_stop = min(n_start + block_n, N)
            b_tile = B[n_start:n_stop, k_start:k_stop]
            tile_scale = a_scale * Bs[n_idx][k_idx]
            acc[:, n_start:n_stop] += torch.matmul(a_tile, b_tile.t()) * tile_scale

    return acc.reshape(result_shape).to(output_dtype)
def apply_w8a8_block_int8_linear( def apply_w8a8_block_int8_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
...@@ -69,7 +70,15 @@ class SampleResultArgsType: ...@@ -69,7 +70,15 @@ class SampleResultArgsType:
sampling_metadata: SamplingMetadata sampling_metadata: SamplingMetadata
greedy_samples: Optional[torch.Tensor] greedy_samples: Optional[torch.Tensor]
beam_search_logprobs: Optional[torch.Tensor] beam_search_logprobs: Optional[torch.Tensor]
# Implemented by guanyu
@dataclass(eq=False)
class SampleDeviceToDevices:
    """Mutable holder for sampling results that are kept as tensors.

    The attributes are declared as real dataclass fields so the generated
    ``__repr__`` shows them. ``eq=False`` keeps identity comparison, which
    avoids an ambiguous truth-value error from comparing tensor fields.
    """

    # First sequence id of every sequence group in the current batch.
    seq_id: Optional[torch.Tensor] = None
    # Sampled token ids, kept as a tensor.
    sampled_token_ids_tensor: Optional[torch.Tensor] = None
    # True when the zero-overhead sampling path is enabled.
    zero_overhead: bool = False
d2d_data = SampleDeviceToDevices()
# Union of non-deferred (single-step scheduling) # Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling) # vs deferred (multi-step scheduling)
...@@ -137,6 +146,9 @@ class SamplerOutput( ...@@ -137,6 +146,9 @@ class SamplerOutput(
# tree-style cartesian candidates # tree-style cartesian candidates
tree_attn_masks: Optional[torch.Tensor] = None tree_attn_masks: Optional[torch.Tensor] = None
sampler_out_tenosr : Optional[torch.Tensor] = None
sampler_out_ids : Optional[torch.Tensor] = None
def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx] return self.outputs[idx]
...@@ -167,7 +179,10 @@ class SamplerOutput( ...@@ -167,7 +179,10 @@ class SamplerOutput(
f"sampled_token_ids={sampled_token_ids_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, " f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, "
f"logits={self.logits}, " f"logits={self.logits}, "
f"tree_attn_masks={self.tree_attn_masks})") f"tree_attn_masks={self.tree_attn_masks}, "
f"sampler_out_tenosr={self.sampler_out_tenosr}, "
f"sampler_out_ids={self.sampler_out_ids}, "
f")")
class Sampler(nn.Module): class Sampler(nn.Module):
...@@ -199,6 +214,8 @@ class Sampler(nn.Module): ...@@ -199,6 +214,8 @@ class Sampler(nn.Module):
# speculative decoding. # speculative decoding.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False self.should_modify_greedy_probs_inplace = False
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
d2d_data.zero_overhead = self.zero_overhead
def _init_sampling_tensors( def _init_sampling_tensors(
self, self,
...@@ -295,7 +312,6 @@ class Sampler(nn.Module): ...@@ -295,7 +312,6 @@ class Sampler(nn.Module):
probs = torch.softmax(logits, dim=-1, dtype=torch.float) probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities. # Compute the log probabilities.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
probs, probs,
...@@ -460,6 +476,7 @@ def _greedy_sample( ...@@ -460,6 +476,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
if not d2d_data.zero_overhead:
samples_lst = samples.tolist() samples_lst = samples.tolist()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -473,6 +490,10 @@ def _greedy_sample( ...@@ -473,6 +490,10 @@ def _greedy_sample(
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] #place holder token id
else:
next_token_ids = [samples_lst[sample_idx]] next_token_ids = [samples_lst[sample_idx]]
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
...@@ -496,6 +517,7 @@ def _random_sample( ...@@ -496,6 +517,7 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
# Find the maximum n value of the prompt phase requests. # Find the maximum n value of the prompt phase requests.
if not d2d_data.zero_overhead:
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
...@@ -511,11 +533,19 @@ def _random_sample( ...@@ -511,11 +533,19 @@ def _random_sample(
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
parent_ids = [0] * sampling_params.n parent_ids = [0] * sampling_params.n
if d2d_data.zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * sampling_params.n #place holder token id
else:
next_token_ids = random_samples[ next_token_ids = random_samples[
sample_idx, :sampling_params.n].tolist() sample_idx, :sampling_params.n].tolist()
else: else:
# Generation phase. # Generation phase.
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
if d2d_data.zero_overhead:
assert num_parent_seqs == 1 # not support muti seqences in seqence group
next_token_ids = [0] * num_parent_seqs #place holder token id
else:
next_token_ids = random_samples[sample_idx:sample_idx + next_token_ids = random_samples[sample_idx:sample_idx +
num_parent_seqs, 0].tolist() num_parent_seqs, 0].tolist()
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
...@@ -689,7 +719,6 @@ def get_pythonized_sample_results( ...@@ -689,7 +719,6 @@ def get_pythonized_sample_results(
sample_result_args.beam_search_logprobs, sample_result_args.beam_search_logprobs,
sample_result_args.sample_results_dict, sample_result_args.sample_results_dict,
) )
for sampling_type in SamplingType: for sampling_type in SamplingType:
if sampling_type not in sample_metadata: if sampling_type not in sample_metadata:
continue continue
...@@ -734,12 +763,13 @@ def _sample_with_torch( ...@@ -734,12 +763,13 @@ def _sample_with_torch(
t: [] t: []
for t in SamplingType for t in SamplingType
} }
d2d_data.seq_id = torch.zeros(len(sampling_metadata.seq_groups), dtype=torch.int32)
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
d2d_data.seq_id[i] = seq_group.seq_ids[0]
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: SampleResultsDictType = {} sample_results_dict: SampleResultsDictType = {}
sample_metadata: SampleMetadataType = {} sample_metadata: SampleMetadataType = {}
multinomial_samples: MultinomialSamplesType = {} multinomial_samples: MultinomialSamplesType = {}
...@@ -771,6 +801,9 @@ def _sample_with_torch( ...@@ -771,6 +801,9 @@ def _sample_with_torch(
greedy_samples = torch.argmax(logprobs[long_sample_indices], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
if d2d_data.zero_overhead:
d2d_data.sampled_token_ids_tensor = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[
...@@ -808,6 +841,10 @@ def _sample_with_torch( ...@@ -808,6 +841,10 @@ def _sample_with_torch(
max_n_in_batch, max_n_in_batch,
seq_groups=seq_groups_arg) seq_groups=seq_groups_arg)
if d2d_data.zero_overhead:
d2d_data.sampled_token_ids_tensor = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[long_sample_indices] = \ sampled_token_ids_tensor[long_sample_indices] = \
...@@ -1271,7 +1308,9 @@ def _build_sampler_output( ...@@ -1271,7 +1308,9 @@ def _build_sampler_output(
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor, logprobs=logprobs_tensor,
deferred_sample_results_args=deferred_sample_results_args, deferred_sample_results_args=deferred_sample_results_args,
logits=logits) logits=logits,
sampler_out_tenosr = d2d_data.sampled_token_ids_tensor,
sampler_out_ids = d2d_data.seq_id)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
......
import torch
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
    sample_output,
    seq_ids,
    input_tokens,
    input_seq_ids,
    BATCH_SIZE1,
    BATCH_SIZE2,
):
    """Overwrite input tokens with sampled tokens whose sequence id matches.

    Each program handles the element at index ``pid`` of ``input_tokens`` /
    ``input_seq_ids``. It scans the first ``BATCH_SIZE1`` entries of
    ``seq_ids``; for every entry equal to its own sequence id it takes the
    value at the same index of ``sample_output``. The token is left
    unchanged when nothing matches.
    """
    pid = tl.program_id(0)
    # Programs beyond the number of input entries have nothing to do.
    if pid >= BATCH_SIZE2:
        return
    # Start from the current token so a miss writes back the same value.
    output_token = tl.load(input_tokens + pid)
    _input_seq_id = tl.load(input_seq_ids + pid)
    # Linear scan; the loop does not break, so the last match wins.
    for i in range(BATCH_SIZE1):
        _seq_ids = tl.load(seq_ids + i)
        if _seq_ids == _input_seq_id:
            output_token = tl.load(sample_output + i)
    tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
    """Launch ``_update_input_tokens`` with one program per input entry.

    Args:
        input_tokens: Buffer the kernel reads from and writes back to.
        input_seq_ids: Sequence id for each entry of ``input_tokens``; its
            first dimension sets the launch grid size.
        last_sample: Values passed to the kernel as ``sample_output``.
        last_ids: Sequence ids passed to the kernel as ``seq_ids``.
    """
    num_inputs = input_seq_ids.shape[0]
    num_samples = last_ids.shape[0]
    launch_grid = [num_inputs, 1, 1]
    _update_input_tokens[launch_grid](
        last_sample,
        last_ids,
        input_tokens,
        input_seq_ids,
        num_samples,
        num_inputs,
    )
\ No newline at end of file
...@@ -514,7 +514,6 @@ class SamplingTensors: ...@@ -514,7 +514,6 @@ class SamplingTensors:
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
do_penalties = prompt_tokens or output_tokens do_penalties = prompt_tokens or output_tokens
if do_penalties: if do_penalties:
prompt_t = make_tensor_with_pad( prompt_t = make_tensor_with_pad(
prompt_tokens, prompt_tokens,
...@@ -534,7 +533,6 @@ class SamplingTensors: ...@@ -534,7 +533,6 @@ class SamplingTensors:
empty_tensor = torch.empty(0, device=device, dtype=torch.long) empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor prompt_t = empty_tensor
output_t = empty_tensor output_t = empty_tensor
temperatures_t = torch.tensor( temperatures_t = torch.tensor(
temperatures, temperatures,
device="cpu", device="cpu",
......
from ctypes import *
import os
import time
import threading
class Prof:
    """Optional NVTX / ROCTX range markers for profiling.

    NVTX markers are enabled when the environment variable
    ``VLLM_PROF_NVTX`` is set, ROCTX markers when ``VLLM_PROF_ROCTX`` is
    set. ROCTX ranges are only emitted between ``StartTracer`` and
    ``StopTracer``. Each backend keeps its own library handle so both can
    be enabled together. ``lib`` is kept for backward compatibility and
    refers to the most recently loaded library.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.roc_tracer_flag = False
        self.lib = None
        self._nvtx_lib = None
        self._roctx_lib = None
        if self.use_nvtx:
            nvtx = cdll.LoadLibrary("libnvToolsExt.so")
            nvtx.nvtxRangePushA.argtypes = [c_char_p]
            nvtx.nvtxRangePushA.restype = c_int
            nvtx.nvtxRangePop.restype = c_int
            self._nvtx_lib = nvtx
            self.lib = nvtx
        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        if self.use_roctx:
            self._ensure_roctx_lib()
        self.tm = time.perf_counter()
        # Number of currently open ranges, keyed by thread id.
        self.push_depth = {}

    def _ensure_roctx_lib(self):
        """Load and configure the roctracer library once; return it."""
        if getattr(self, "_roctx_lib", None) is None:
            roctx = cdll.LoadLibrary("libroctracer64.so")
            roctx.roctxRangePushA.argtypes = [c_char_p]
            roctx.roctxRangePushA.restype = c_int
            roctx.roctxRangePop.restype = c_int
            self._roctx_lib = roctx
            self.lib = roctx
        return self._roctx_lib

    def StartTracer(self):
        """Start roctracer and allow ROCTX ranges (no-op without ROCTX)."""
        if self.use_roctx:
            self._ensure_roctx_lib().roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        """Stop roctracer and suppress ROCTX ranges (no-op without ROCTX)."""
        if self.use_roctx:
            self._ensure_roctx_lib().roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's open-range counter by ``num``.

        Returns:
            ``False`` when ``num`` is negative and the counter is already
            zero (counter left unchanged), otherwise ``True``.
        """
        thread_id = threading.current_thread().ident
        depth = self.push_depth.setdefault(thread_id, 0)
        if num < 0 and depth == 0:
            return False
        self.push_depth[thread_id] = depth + num
        return True

    def ProfRangePush(self, message):
        """Open a range named ``message`` on every enabled backend."""
        if self.use_nvtx:
            self._nvtx_lib.nvtxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)
        if self.use_roctx and self.roc_tracer_flag:
            self._ensure_roctx_lib().roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range on every enabled backend.

        Returns early without popping when the calling thread has no open
        range.
        """
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self._nvtx_lib.nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self._ensure_roctx_lib().roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range, if any, then open one named ``message``."""
        self.ProfRangePop()
        self.ProfRangePush(message)
profile = Prof()
...@@ -7,6 +7,7 @@ from array import array ...@@ -7,6 +7,7 @@ from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import reduce from functools import reduce
import os
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union from typing import Set, Tuple, Union
...@@ -178,6 +179,8 @@ class SequenceData(msgspec.Struct, ...@@ -178,6 +179,8 @@ class SequenceData(msgspec.Struct,
_first_step_flag: bool = True _first_step_flag: bool = True
_effective_length: int = 0
@staticmethod @staticmethod
def from_prompt_token_counts( def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData": *token_counts: Tuple[int, int]) -> "SequenceData":
...@@ -308,15 +311,30 @@ class SequenceData(msgspec.Struct, ...@@ -308,15 +311,30 @@ class SequenceData(msgspec.Struct,
self._cached_all_token_ids.append(token_id) self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob self._cumulative_logprob += logprob
def fix_effective_token_id(self, token_id: int,):
# Write ``token_id`` into the stored token lists at the slot just past the
# first ``_effective_length`` output tokens, then advance ``_effective_length``.
# NOTE(review): indentation was lost in this view, so the exact nesting of the
# statements below (in particular whether the final increment is conditional)
# cannot be confirmed here -- verify against the original file.
# Negative when more output tokens are stored than ``_effective_length``
# accounts for; it is then used as an index counted from the end of each list.
effect_offset = self._effective_length - len(self.output_token_ids)
if effect_offset < 0:
self._output_token_ids[effect_offset] = token_id
# Only index ``_new_appended_tokens`` when it is long enough for the
# negative index to be valid.
if len(self._new_appended_tokens) >= effect_offset * -1:
self._new_appended_tokens[effect_offset] = token_id
self._cached_all_token_ids[effect_offset] = token_id
self._effective_length += 1
def get_len(self) -> int: def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids) return len(self._output_token_ids) + len(self._prompt_token_ids)
def zero_overhead_get_len(self) -> int:
    """Return ``_effective_length`` plus the number of prompt tokens."""
    effective_output_len = self._effective_length
    prompt_len = len(self._prompt_token_ids)
    return effective_output_len + prompt_len
def get_prompt_len(self) -> int: def get_prompt_len(self) -> int:
return len(self._prompt_token_ids) return len(self._prompt_token_ids)
def get_output_len(self) -> int: def get_output_len(self) -> int:
return len(self._output_token_ids) return len(self._output_token_ids)
def zero_overhead_get_output_len(self) -> int:
    """Return the value of ``_effective_length``."""
    effective_output_len = self._effective_length
    return effective_output_len
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:
return self._cached_all_token_ids return self._cached_all_token_ids
...@@ -367,15 +385,22 @@ class SequenceData(msgspec.Struct, ...@@ -367,15 +385,22 @@ class SequenceData(msgspec.Struct,
# of prompt_len here. This is because during recompute we need to # of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output. # prefill for both prompt and output.
return self.get_len() - self.get_num_computed_tokens() return self.get_len() - self.get_num_computed_tokens()
def get_last_token_id(self) -> int: def get_last_token_id(self) -> int:
if not self._output_token_ids: if not self._output_token_ids:
return self._prompt_token_ids[-1] return self._prompt_token_ids[-1]
return self._output_token_ids[-1] return self._output_token_ids[-1]
def zero_overhead_get_last_token_id(self) -> int:
    """Return the output token at index ``_effective_length - 1``.

    When ``_effective_length`` is zero the last prompt token is returned
    instead.
    """
    effective_len = self._effective_length
    if effective_len != 0:
        return self._output_token_ids[effective_len - 1]
    return self._prompt_token_ids[-1]
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.prompt_token_ids return self.prompt_token_ids
def zero_overhead_get_output_token_ids(self) -> Tuple[int, ...]:
    """Return the first ``_effective_length`` entries of ``output_token_ids``."""
    all_output_ids = self.output_token_ids
    effective_len = self._effective_length
    return all_output_ids[:effective_len]
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids return self.output_token_ids
...@@ -461,6 +486,7 @@ class Sequence: ...@@ -461,6 +486,7 @@ class Sequence:
self.read_offset = 0 self.read_offset = 0
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
@property @property
def n_blocks(self) -> int: def n_blocks(self) -> int:
...@@ -527,9 +553,9 @@ class Sequence: ...@@ -527,9 +553,9 @@ class Sequence:
"""If delta is True, only new tokens since the last call to """If delta is True, only new tokens since the last call to
this method are returned""" this method are returned"""
if not delta: if not delta:
return self.get_output_token_ids() return self.get_output_token_ids(self.zero_overhead)
output_len = self.get_output_len() output_len = self.get_output_len(self.zero_overhead)
# Get the number of new tokens # Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset num_new_tokens = output_len - self._last_output_token_ids_offset
...@@ -539,11 +565,16 @@ class Sequence: ...@@ -539,11 +565,16 @@ class Sequence:
if num_new_tokens == 1: if num_new_tokens == 1:
# Optimization for single decode token case # Optimization for single decode token case
# (which is what we have most of the time) # (which is what we have most of the time)
if self.zero_overhead:
return self.data._cached_all_token_ids[self.data._effective_length - 1]
else:
return self.data._cached_all_token_ids[-1] return self.data._cached_all_token_ids[-1]
if num_new_tokens == 0: if num_new_tokens == 0:
return [] return []
if self.zero_overhead:
return self.data._cached_all_token_ids[-num_new_tokens : self.data._effective_length]
return self.data._cached_all_token_ids[-num_new_tokens:] return self.data._cached_all_token_ids[-num_new_tokens:]
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
...@@ -582,13 +613,20 @@ class Sequence: ...@@ -582,13 +613,20 @@ class Sequence:
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
def get_len(self) -> int: def fix_last_token_id(self, token_id: int) -> None:
self.data.fix_effective_token_id(token_id)
def get_len(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_len()
return self.data.get_len() return self.data.get_len()
def get_prompt_len(self) -> int: def get_prompt_len(self) -> int:
return self.data.get_prompt_len() return self.data.get_prompt_len()
def get_output_len(self) -> int: def get_output_len(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_output_len()
return self.data.get_output_len() return self.data.get_output_len()
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:
...@@ -597,10 +635,14 @@ class Sequence: ...@@ -597,10 +635,14 @@ class Sequence:
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.data.get_prompt_token_ids() return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int: def get_last_token_id(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_last_token_id()
return self.data.get_last_token_id() return self.data.get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self, zero_overhead = False) -> Tuple[int, ...]:
if zero_overhead:
return self.data.zero_overhead_get_output_token_ids()
return self.data.get_output_token_ids() return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float: def get_cumulative_logprob(self) -> float:
...@@ -807,6 +849,7 @@ class SequenceGroup: ...@@ -807,6 +849,7 @@ class SequenceGroup:
def set_last_token_time(self, now: float) -> None: def set_last_token_time(self, now: float) -> None:
"""Sets the last token time for Request level timings.""" """Sets the last token time for Request level timings."""
# If still in prefill phase, assertion fails. # If still in prefill phase, assertion fails.
if not self.seqs[0].zero_overhead:
assert not self.is_prefill(), ( assert not self.is_prefill(), (
"seq_group.set_last_token_time() should not be called " "seq_group.set_last_token_time() should not be called "
"if the seq_group is in prefill phase.") "if the seq_group is in prefill phase.")
...@@ -815,6 +858,7 @@ class SequenceGroup: ...@@ -815,6 +858,7 @@ class SequenceGroup:
def get_last_token_latency(self) -> float: def get_last_token_latency(self) -> float:
"""Returns the latency of the last token.""" """Returns the latency of the last token."""
if not self.seqs[0].zero_overhead:
assert not self.is_prefill(), ( assert not self.is_prefill(), (
"seq_group.get_last_token_latency() should not be called " "seq_group.get_last_token_latency() should not be called "
"if the seq_group is in prefill phase.") "if the seq_group is in prefill phase.")
...@@ -1402,6 +1446,12 @@ class ExecuteModelRequest( ...@@ -1402,6 +1446,12 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model. # Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved: Optional[torch.Tensor] = None kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_sample : Optional[torch.Tensor] = None
# for zero-overhead scheduler
last_outputs_ids : Optional[torch.Tensor] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of # TODO(will) make this be able to handle batches with variable number of
...@@ -1451,7 +1501,9 @@ class ExecuteModelRequest( ...@@ -1451,7 +1501,9 @@ class ExecuteModelRequest(
async_callback=self.async_callback, async_callback=self.async_callback,
tree_attn_masks=self.tree_attn_masks, tree_attn_masks=self.tree_attn_masks,
tree_position_ids=self.tree_position_ids, tree_position_ids=self.tree_position_ids,
kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved) kvcache_slot_to_be_moved=self.kvcache_slot_to_be_moved,
last_outputs_sample = self.last_outputs_sample,
last_outputs_ids = self.last_outputs_ids)
@dataclass @dataclass
......
...@@ -690,6 +690,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -690,6 +690,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states = hidden_states[ hidden_states = hidden_states[
torch.where(sampler_output.sampled_token_ids - torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]] VLLM_INVALID_TOKEN_ID)[0]]
if not skip_proposer:
if self.previous_hidden_states is None and len( if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden): seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates( self.previous_hidden_states = HiddenStates(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from typing import List, Optional
import torch
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import (ModelRunnerBase, from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase, ModelRunnerInputBase,
...@@ -31,10 +31,12 @@ class TargetModelRunner(ModelRunnerWrapperBase): ...@@ -31,10 +31,12 @@ class TargetModelRunner(ModelRunnerWrapperBase):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> ModelRunnerInputBase: ) -> ModelRunnerInputBase:
model_input: ModelRunnerInputBase =\ model_input: ModelRunnerInputBase =\
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
seq_group_metadata_list, virtual_engine, finished_requests_ids) seq_group_metadata_list, virtual_engine, finished_requests_ids, last_outputs_ids, last_output_sample)
# If token log probabilities is disabled then skip generating sampler # If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors # CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the # as needed. If log probabilities is enabled then synchronize all the
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
...@@ -16,6 +17,7 @@ class Detokenizer: ...@@ -16,6 +17,7 @@ class Detokenizer:
def __init__(self, tokenizer_group: BaseTokenizerGroup): def __init__(self, tokenizer_group: BaseTokenizerGroup):
self.tokenizer_group = tokenizer_group self.tokenizer_group = tokenizer_group
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
"""Returns the HF tokenizer to use for a given sequence.""" """Returns the HF tokenizer to use for a given sequence."""
...@@ -108,6 +110,10 @@ class Detokenizer: ...@@ -108,6 +110,10 @@ class Detokenizer:
The number of characters added to the output text. The number of characters added to the output text.
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
if self.zero_overhead:
eff_length = seq.get_prompt_len() + seq.data._effective_length
all_input_ids = seq.get_token_ids()[ : eff_length]
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq) tokenizer = self.get_tokenizer_for_seq(seq)
......
# SPDX-License-Identifier: Apache-2.0
try: try:
from ._version import __version__, __version_tuple__ __version__ = "0.7.2"
__version_tuple__ = (0, 7, 2)
__hcu_version__ = f'0.7.2+das.opt1.cust1.6b7651a.dtk2504'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e: except Exception as e:
import warnings import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__version__ = "dev" __version__ = "dev"
__version_tuple__ = (0, 0, __version__) __version_tuple__ = (0, 0, __version__)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import sys
import dataclasses import dataclasses
import gc import gc
import inspect import inspect
import itertools import itertools
import os
import time import time
import weakref import weakref
from contextlib import contextmanager from contextlib import contextmanager
...@@ -59,6 +61,8 @@ from vllm.worker.model_runner_base import ( ...@@ -59,6 +61,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.model_executor.layers.update_input import UpdateInputTokens
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -271,7 +275,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -271,7 +275,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.computed_block_nums = computed_block_nums self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs self.n_seqs = n_seqs
self.encoder_seq_len = encoder_seq_len self.encoder_seq_len = encoder_seq_len
if reinit: if reinit:
if len(self.seq_ids) == 1 and reinit_use_defaults: if len(self.seq_ids) == 1 and reinit_use_defaults:
self.simple_reinit() self.simple_reinit()
...@@ -475,6 +478,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -475,6 +478,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.sliding_window_blocks * self.block_size self.sliding_window_blocks * self.block_size
self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder self.is_encoder_decoder_model = self.runner.model_config.is_encoder_decoder
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
self.last_sample_tensor = None
self.last_sample_ids = None
self.req_ids = []
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
    """Store the given sample ids and sample tensor on the builder.

    Args:
        last_sample_ids: Value saved as ``self.last_sample_ids``.
        last_sample_tensor: Value saved as ``self.last_sample_tensor``.
    """
    self.last_sample_tensor, self.last_sample_ids = (last_sample_tensor,
                                                     last_sample_ids)
def prepare(self, def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None: finished_requests_ids: Optional[List[str]] = None) -> None:
...@@ -490,6 +501,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -490,6 +501,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
ModelInputForGPUBuilder.InterDataForSeqGroup] = [] ModelInputForGPUBuilder.InterDataForSeqGroup] = []
self.attn_metadata_builder.prepare() self.attn_metadata_builder.prepare()
self.req_ids.clear()
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
...@@ -755,8 +767,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -755,8 +767,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len=encoder_seq_len) encoder_seq_len=encoder_seq_len)
self.inter_data_list.append(inter_data) self.inter_data_list.append(inter_data)
seq_ids = list(seq_ids)
for seq_idx in range(n_seqs): for seq_idx in range(n_seqs):
self.req_ids.append(seq_ids[seq_idx])
for per_seq_fn in self.per_seq_compute_fns: for per_seq_fn in self.per_seq_compute_fns:
per_seq_fn(inter_data, seq_idx, seq_group_metadata) per_seq_fn(inter_data, seq_idx, seq_group_metadata)
for per_seq_group_fn in self.per_seq_group_compute_fns: for per_seq_group_fn in self.per_seq_group_compute_fns:
...@@ -897,10 +910,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -897,10 +910,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if cuda_graph_pad_size: if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
if self.zero_overhead and self.last_sample_tensor is not None:
input_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
last_ids = async_tensor_h2d(self.last_sample_ids.tolist(), torch.long,
self.runner.device,
self.runner.pin_memory)
UpdateInputTokens(input_tokens_tensor, input_ids, self.last_sample_tensor, last_ids)
token_types_tensor = async_tensor_h2d(token_types, torch.long, token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) \ self.runner.pin_memory) \
...@@ -1109,6 +1132,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1109,6 +1132,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# multi-step model runner does not have `_builder_cls` # multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self)) self.builder = self._builder_cls(weakref.proxy(self))
self.enforce_eager_bs_threshould = sys.maxsize
if envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD is not None and envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD > 0:
self.enforce_eager_bs_threshould = envs.VLLM_ENFORCE_EAGER_BS_THRESHOLD
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
...@@ -1198,7 +1225,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1198,7 +1225,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> TModelInputForGPU: ) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not group. Prepares metadata needed for the base model forward pass but not
...@@ -1219,7 +1248,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1219,7 +1248,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.builder.add_seq_group(seq_group_metadata) self.builder.add_seq_group(seq_group_metadata)
self.builder.reset_cached_inter_data() self.builder.reset_cached_inter_data()
self.builder.SetLastSamperData(last_outputs_ids, last_output_sample)
return self.builder.build() # type: ignore return self.builder.build() # type: ignore
@contextmanager @contextmanager
...@@ -1614,6 +1643,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1614,6 +1643,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1629,7 +1660,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1629,7 +1660,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs. If cuda graph is required, this API automatically pads inputs.
""" """
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids) seq_group_metadata_list, finished_requests_ids, last_outputs_ids, last_output_sample)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group # Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids) generators = self.get_generators(finished_requests_ids)
...@@ -1670,7 +1701,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1670,7 +1701,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_prompt_adapters( self.set_active_prompt_adapters(
model_input.prompt_adapter_requests, model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping) model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input) self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
...@@ -1680,7 +1710,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1680,7 +1710,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# TODO(andoorve): We can remove this once all # TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache. # virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph and \
model_input.input_tokens.shape[0] <= self.enforce_eager_bs_threshould:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][ model_executable = self.graph_runners[virtual_engine][
......
...@@ -210,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -210,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]):
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None, finished_requests_ids: Optional[List[str]] = None,
last_outputs_ids: torch.Tensor = None,
last_output_sample: torch.Tensor = None,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import dataclasses import dataclasses
import os import os
import numa
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...@@ -28,6 +29,23 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, ...@@ -28,6 +29,23 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
logger = init_logger(__name__) logger = init_logger(__name__)
# Bind the current process to a NUMA node.
def bind_to_numa(local_rank):
    """Bind the current process to the NUMA node configured for ``local_rank``.

    The node index is read from the environment variable
    ``VLLM_RANK{local_rank}_NUMA``.  When the variable is unset, negative, or
    not an integer, a warning is logged and no binding is performed.

    Args:
        local_rank: Rank used to build the environment variable name.

    Raises:
        ValueError: If the configured node index is greater than
            ``numa.get_max_node()``.
    """
    env_str = f"VLLM_RANK{local_rank}_NUMA"
    raw_value = os.getenv(env_str)
    try:
        numa_node = -1 if raw_value is None else int(raw_value)
    except ValueError:
        # A non-integer value is a misconfiguration: skip binding instead of
        # crashing the worker.
        logger.warning("%s is unset or set incorrectly, vllm will not bind to numa! %s = %r", env_str, env_str, raw_value)
        return
    # Unset or invalid configuration means no binding.
    # TODO: bind automatically based on the hardware topology.
    if numa_node < 0:
        logger.warning("%s is unset or set incorrectly, vllm will not bind to numa! %s = %d", env_str, env_str, numa_node)
        return
    if numa_node > numa.get_max_node():
        raise ValueError(f"NUMA node {numa_node} is not available.")
    numa.bind([numa_node])
class WorkerBase(ABC): class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for """Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to different hardware. Also abstracts control plane communication, e.g., to
...@@ -356,7 +374,9 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -356,7 +374,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine, execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids)) execute_model_req.finished_requests_ids,
last_outputs_ids = execute_model_req.last_outputs_ids,
last_output_sample = execute_model_req.last_outputs_sample))
if self.tree_decoding and execute_model_req.tree_position_ids is not None and \ if self.tree_decoding and execute_model_req.tree_position_ids is not None and \
execute_model_req.tree_attn_masks is not None: execute_model_req.tree_attn_masks is not None:
...@@ -444,7 +464,6 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -444,7 +464,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
and self.observability_config.collect_model_execute_time): and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get( orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item() "model_execute_time", torch.tensor(0)).item()
output = self.model_runner.execute_model( output = self.model_runner.execute_model(
model_input=model_input, model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine] kv_caches=self.kv_cache[worker_input.virtual_engine]
...@@ -595,6 +614,16 @@ class WorkerWrapperBase: ...@@ -595,6 +614,16 @@ class WorkerWrapperBase:
self.worker = worker_class(**kwargs) self.worker = worker_class(**kwargs)
assert self.worker is not None assert self.worker is not None
VLLM_NUMA_BIND = int(os.getenv("VLLM_NUMA_BIND", 1))
if VLLM_NUMA_BIND > 0:
# 绑定当前进程到指定 NUMA 节点
bind_to_numa(kwargs['local_rank'])
pid = os.getpid()
logger.info("########## %d process(rank%s) is running on CPU(s): %s", pid, str(kwargs['local_rank']), str(os.sched_getaffinity(pid)))
logger.info("########## %d process(rank%s) is running on memnode(s): %s", pid, str(kwargs['local_rank']), str(numa.get_membind()))
def execute_method(self, method: Union[str, bytes], *args, **kwargs): def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try: try:
target = self if self.worker is None else self.worker target = self if self.worker is None else self.worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment