Unverified Commit 32c9be22 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[v1] Re-add fp32 support to v1 engine through FlexAttention (#19754)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 8aeaa910
...@@ -68,7 +68,7 @@ jobs: ...@@ -68,7 +68,7 @@ jobs:
export AWS_ACCESS_KEY_ID=minioadmin export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
- name: curl test - name: curl test
run: | run: |
......
...@@ -181,6 +181,34 @@ def test_env( ...@@ -181,6 +181,34 @@ def test_env(
assert backend.get_name() == expected assert backend.get_name() == expected
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
device: str,
use_v1: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Test attention backend selection with fp32."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
if use_v1 else "TORCH_SDPA")
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "FLEX_ATTENTION"
if use_v1 else "XFORMERS")
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to # TODO: When testing for v1, pipe in `use_v1` as an argument to
......
...@@ -450,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): ...@@ -450,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer" error_msg = f"{layer_1} must come before the current layer"
...@@ -478,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): ...@@ -478,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn" invalid_layer = "model.layers.0.cross_attn.attn"
...@@ -506,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): ...@@ -506,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(): def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer" error_msg = f"{layer_1} cannot be the same as the current layer"
...@@ -534,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(): ...@@ -534,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_without_kv_sharing(): def test_init_kv_cache_without_kv_sharing():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config() vllm_config = get_vllm_config()
...@@ -601,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing(): ...@@ -601,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_with_kv_sharing_valid(): def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config() vllm_config = get_vllm_config()
......
...@@ -1393,13 +1393,6 @@ class EngineArgs: ...@@ -1393,13 +1393,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# Only Fp16 and Bf16 dtypes since we only support FA.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
if model_config.dtype not in V1_SUPPORTED_DTYPES:
_raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
recommend_to_remove=False)
return False
# No Mamba or Encoder-Decoder so far. # No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible: if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures, _raise_or_fallback(feature_name=model_config.architectures,
......
...@@ -104,8 +104,12 @@ class TensorizerLoader(BaseModelLoader): ...@@ -104,8 +104,12 @@ class TensorizerLoader(BaseModelLoader):
if is_vllm_tensorized(self.tensorizer_config): if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config) tensorizer_config = self._patch_tensorizer_config(model_config)
model = init_tensorizer_model(tensorizer_config=tensorizer_config, device_config = vllm_config.device_config
vllm_config=vllm_config) with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = init_tensorizer_model(
tensorizer_config=tensorizer_config,
vllm_config=vllm_config)
self.load_weights(model, model_config) self.load_weights(model, model_config)
return model return model
return self._load_model_serialized_cpu(vllm_config=vllm_config) return self._load_model_serialized_cpu(vllm_config=vllm_config)
......
...@@ -251,6 +251,10 @@ class CudaPlatformBase(Platform): ...@@ -251,6 +251,10 @@ class CudaPlatformBase(Platform):
# Default backends for V1 engine # Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed # Prefer FlashInfer for Blackwell GPUs if installed
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100): if cls.is_device_capability(100):
try: try:
import flashinfer # noqa: F401 import flashinfer # noqa: F401
......
...@@ -463,6 +463,13 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -463,6 +463,13 @@ class FlexAttentionImpl(AttentionImpl):
query = query[:, :, :num_actual_tokens, :] query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation # Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2) # torch._dynamo.try_mark_dynamic(query, 2)
# default M=64, N=64 may run out of shared memory on
# some GPUs with fp32, so we use smaller M and N.
extra_kernel_options = {
"BLOCK_M": 32,
"BLOCK_N": 32
} if query.dtype == torch.float32 else {}
out = flex_attention_compiled( out = flex_attention_compiled(
query, query,
key_cache, key_cache,
...@@ -471,7 +478,10 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -471,7 +478,10 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata.block_mask, attn_metadata.block_mask,
self.scale, self.scale,
enable_gqa=enable_gqa, enable_gqa=enable_gqa,
kernel_options={"FORCE_USE_FLEX_ATTENTION": True}, kernel_options={
"FORCE_USE_FLEX_ATTENTION": True,
**extra_kernel_options
},
) )
# Flex doesn't have an out variant today, rely on epilogue fusion # Flex doesn't have an out variant today, rely on epilogue fusion
......
...@@ -101,7 +101,10 @@ class TopKTopPSampler(nn.Module): ...@@ -101,7 +101,10 @@ class TopKTopPSampler(nn.Module):
"per-request generators. Falling back to " "per-request generators. Falling back to "
"PyTorch-native implementation.") "PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p) return self.forward_native(logits, generators, k, p)
return flashinfer_sample(logits, k, p, generators) # flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators)
def forward_tpu( def forward_tpu(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment