Unverified Commit 39029d51 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files
parent 35d801f1
......@@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import Literal, NamedTuple
import pytest
import torch
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
......@@ -254,6 +255,17 @@ def test_cp_generation(
test_options: CPTestOptions,
num_gpus_available,
):
if (
model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
and torch.cuda.get_device_capability() < (9, 0)
):
pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
if (
model_id == "bigcode/gpt_bigcode-santacoder"
and torch.cuda.get_device_capability() != (9, 0)
):
pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")
_compare_cp_with_tp(
model_id,
parallel_setup,
......
......@@ -195,7 +195,6 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment