Unverified commit eebad39f, authored by youkaichao, committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent db100c5c
......@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
......@@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
......@@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer(
config, cache_config, quant_config=quant_config),
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers",
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
......
......@@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
......@@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config,
bias=getattr(config, "bias", False),
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = XverseMLP(
hidden_size=self.hidden_size,
......@@ -243,8 +247,8 @@ class XverseModel(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: XverseDecoderLayer(config, cache_config,
quant_config),
lambda prefix: XverseDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
......@@ -20,6 +20,7 @@ logger = init_logger(__name__)
class CpuPlatform(Platform):
_enum = PlatformEnum.CPU
device_type: str = "cpu"
dispatch_key: str = "CPU"
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
......
......@@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
......
......@@ -13,6 +13,7 @@ else:
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
device_type: str = "hpu"
dispatch_key: str = "HPU"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
......@@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
class Platform:
_enum: PlatformEnum
device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA
......
......@@ -18,6 +18,7 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO
device_type: str = "openvino"
dispatch_key: str = "CPU"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
......@@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
......@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_type: str = "tpu"
dispatch_key: str = "XLA"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
......@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class XPUPlatform(Platform):
_enum = PlatformEnum.XPU
device_type: str = "xpu"
dispatch_key: str = "XPU"
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
......@@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
if previous_hidden_states is not None else {}
# Run model
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......
......@@ -1573,6 +1573,7 @@ def direct_register_custom_op(
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
):
"""
`torch.library.custom_op` can have significant overhead because it
......@@ -1601,7 +1602,7 @@ def direct_register_custom_op(
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
......
......@@ -173,7 +173,8 @@ def unified_v1_flash_attention(
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> None:
current_metadata = get_forward_context()
context = get_forward_context()
current_metadata = context.dynamic_forward_context
if current_metadata is None:
# Profiling run.
return
......
......@@ -447,7 +447,7 @@ class GPUModelRunner:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=None,
positions=self.positions[:num_input_tokens],
......@@ -523,7 +523,7 @@ class GPUModelRunner:
num_tokens: int,
kv_caches: List[torch.Tensor],
) -> torch.Tensor:
with set_forward_context(None):
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=None,
positions=self.positions[:num_tokens],
......
......@@ -97,7 +97,7 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......
......@@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
} if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......
......@@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)
with set_forward_context(attn_metadata):
with set_forward_context(attn_metadata, self.vllm_config):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
......@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
with set_forward_context(model_input.attn_metadata):
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment