Unverified commit 39029d51, authored by Lucas Wilkinson and committed by GitHub
Browse files
parent 35d801f1
...@@ -14,6 +14,7 @@ from dataclasses import dataclass ...@@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import Literal, NamedTuple from typing import Literal, NamedTuple
import pytest import pytest
import torch
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -254,6 +255,17 @@ def test_cp_generation( ...@@ -254,6 +255,17 @@ def test_cp_generation(
test_options: CPTestOptions, test_options: CPTestOptions,
num_gpus_available, num_gpus_available,
): ):
if (
model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
and torch.cuda.get_device_capability() < (9, 0)
):
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
if (
model_id == "bigcode/gpt_bigcode-santacoder"
and torch.cuda.get_device_capability() != (9, 0)
):
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
_compare_cp_with_tp( _compare_cp_with_tp(
model_id, model_id,
parallel_setup, parallel_setup,
......
...@@ -195,7 +195,6 @@ def cp_lse_ag_out_rs( ...@@ -195,7 +195,6 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous() cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1) out = cp_group.reduce_scatter(out, dim=1)
if return_lse: if return_lse:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment