Unverified Commit eebad39f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[torch.compile] support all attention backends (#10558)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent db100c5c
...@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module): ...@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module): ...@@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config, self.self_attn = Starcoder2Attention(config,
cache_config, cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon) eps=config.norm_epsilon)
...@@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module): ...@@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer( lambda prefix: Starcoder2DecoderLayer(
config, cache_config, quant_config=quant_config), config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
......
...@@ -93,6 +93,7 @@ class XverseAttention(nn.Module): ...@@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -138,7 +139,8 @@ class XverseAttention(nn.Module): ...@@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward( def forward(
self, self,
...@@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
bias=getattr(config, "bias", False), bias=getattr(config, "bias", False),
cache_config=cache_config, cache_config=cache_config,
prefix=f"{prefix}.self_attn",
) )
self.mlp = XverseMLP( self.mlp = XverseMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
...@@ -243,8 +247,8 @@ class XverseModel(nn.Module): ...@@ -243,8 +247,8 @@ class XverseModel(nn.Module):
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: XverseDecoderLayer(config, cache_config, lambda prefix: XverseDecoderLayer(
quant_config), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......
...@@ -20,6 +20,7 @@ logger = init_logger(__name__) ...@@ -20,6 +20,7 @@ logger = init_logger(__name__)
class CpuPlatform(Platform): class CpuPlatform(Platform):
_enum = PlatformEnum.CPU _enum = PlatformEnum.CPU
device_type: str = "cpu" device_type: str = "cpu"
dispatch_key: str = "CPU"
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
......
...@@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int: ...@@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class CudaPlatform(Platform): class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA _enum = PlatformEnum.CUDA
device_type: str = "cuda" device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod @classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
......
...@@ -13,6 +13,7 @@ else: ...@@ -13,6 +13,7 @@ else:
class HpuPlatform(Platform): class HpuPlatform(Platform):
_enum = PlatformEnum.HPU _enum = PlatformEnum.HPU
device_type: str = "hpu" device_type: str = "hpu"
dispatch_key: str = "HPU"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
...@@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple): ...@@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
class Platform: class Platform:
_enum: PlatformEnum _enum: PlatformEnum
device_type: str device_type: str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
def is_cuda(self) -> bool: def is_cuda(self) -> bool:
return self._enum == PlatformEnum.CUDA return self._enum == PlatformEnum.CUDA
......
...@@ -18,6 +18,7 @@ logger = init_logger(__name__) ...@@ -18,6 +18,7 @@ logger = init_logger(__name__)
class OpenVinoPlatform(Platform): class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO _enum = PlatformEnum.OPENVINO
device_type: str = "openvino" device_type: str = "openvino"
dispatch_key: str = "CPU"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
...@@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: ...@@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
device_type: str = "cuda" device_type: str = "cuda"
dispatch_key: str = "CUDA"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
...@@ -17,6 +17,7 @@ logger = init_logger(__name__) ...@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
device_type: str = "tpu" device_type: str = "tpu"
dispatch_key: str = "XLA"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
...@@ -17,6 +17,7 @@ logger = init_logger(__name__) ...@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class XPUPlatform(Platform): class XPUPlatform(Platform):
_enum = PlatformEnum.XPU _enum = PlatformEnum.XPU
device_type: str = "xpu" device_type: str = "xpu"
dispatch_key: str = "XPU"
@classmethod @classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
......
...@@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
if previous_hidden_states is not None else {} if previous_hidden_states is not None else {}
# Run model # Run model
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata,
self.vllm_config):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
......
...@@ -1573,6 +1573,7 @@ def direct_register_custom_op( ...@@ -1573,6 +1573,7 @@ def direct_register_custom_op(
mutates_args: List[str], mutates_args: List[str],
fake_impl: Optional[Callable] = None, fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None, target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
): ):
""" """
`torch.library.custom_op` can have significant overhead because it `torch.library.custom_op` can have significant overhead because it
...@@ -1601,7 +1602,7 @@ def direct_register_custom_op( ...@@ -1601,7 +1602,7 @@ def direct_register_custom_op(
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str) my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA") my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None: if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl) my_lib._register_fake(op_name, fake_impl)
......
...@@ -173,7 +173,8 @@ def unified_v1_flash_attention( ...@@ -173,7 +173,8 @@ def unified_v1_flash_attention(
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
) -> None: ) -> None:
current_metadata = get_forward_context() context = get_forward_context()
current_metadata = context.dynamic_forward_context
if current_metadata is None: if current_metadata is None:
# Profiling run. # Profiling run.
return return
......
...@@ -447,7 +447,7 @@ class GPUModelRunner: ...@@ -447,7 +447,7 @@ class GPUModelRunner:
# Run the decoder. # Run the decoder.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model( hidden_states = self.model(
input_ids=None, input_ids=None,
positions=self.positions[:num_input_tokens], positions=self.positions[:num_input_tokens],
...@@ -523,7 +523,7 @@ class GPUModelRunner: ...@@ -523,7 +523,7 @@ class GPUModelRunner:
num_tokens: int, num_tokens: int,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
with set_forward_context(None): with set_forward_context(None, self.vllm_config):
hidden_states = model( hidden_states = model(
input_ids=None, input_ids=None,
positions=self.positions[:num_tokens], positions=self.positions[:num_tokens],
......
...@@ -97,7 +97,7 @@ class EmbeddingModelRunner( ...@@ -97,7 +97,7 @@ class EmbeddingModelRunner(
model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record() model_forward_start.record()
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
......
...@@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
} if self.has_inner_state else {} } if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
......
...@@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model( self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs) capture_inputs)
with set_forward_context(attn_metadata): with set_forward_context(attn_metadata, self.vllm_config):
graph_runner.capture(**capture_inputs) graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = ( self.graph_runners[virtual_engine][batch_size] = (
...@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record() model_forward_start.record()
with set_forward_context(model_input.attn_metadata): with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment