Unverified commit 2a15840f, authored by Grigory Sizov, committed by GitHub

Enable paged attention in varlen forward (#831)

* Enable paged attention in varlen forward

* Format + fix padding
parent 26c9e827
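
For orientation, here is a minimal sketch of how the new code path is meant to be exercised from Python once this change is built. It only assumes what the diff below states: paged K/V caches of shape num_blocks x page_block_size x num_heads_k x head_size, an int32 block_table of shape batch_size x max_num_blocks_per_seq, and a block size that is a multiple of 256. The concrete sizes and tensor contents are illustrative, not part of the change.

import itertools
import torch
from flash_attn import flash_attn_varlen_func

device, dtype = "cuda", torch.float16
batch_size, nheads, nheads_k, d = 2, 8, 8, 64
seqlens_q = [5, 7]        # per-sequence query lengths
seqlens_k = [300, 500]    # per-sequence key/value lengths already held in the cache
page_block_size = 256     # must be a multiple of 256 (checked in mha_varlen_fwd)

blocks_per_seq = -(-max(seqlens_k) // page_block_size)   # ceil division
num_blocks = batch_size * blocks_per_seq

q = torch.randn(sum(seqlens_q), nheads, d, device=device, dtype=dtype)
# Paged K/V caches: num_blocks x page_block_size x num_heads_k x head_size
k_cache = torch.randn(num_blocks, page_block_size, nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(num_blocks, page_block_size, nheads_k, d, device=device, dtype=dtype)
# block_table: batch_size x max_num_blocks_per_seq, int32 indices of physical blocks
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(
    batch_size, blocks_per_seq
)

cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate(seqlens_q)), dtype=torch.int32, device=device)
cu_seqlens_k = torch.tensor([0] + list(itertools.accumulate(seqlens_k)), dtype=torch.int32, device=device)

out = flash_attn_varlen_func(
    q,
    k_cache,
    v_cache,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=max(seqlens_q),
    max_seqlen_k=max(seqlens_k),
    causal=True,
    block_table=block_table,  # new argument; forward only, backward is skipped for paged KV
)
print(out.shape)  # (total_q, nheads, head_size)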
@@ -494,12 +494,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -535,6 +536,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -546,8 +556,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int total_k = k.size(0);
-    const int num_heads_k = k.size(1);
+    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : k.size(0);
+    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
 
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
@@ -575,8 +589,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    if (!paged_KV) {
+        const int total_k = k.size(0);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
@@ -654,6 +676,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                      window_size_left,
                      window_size_right,
                      seqlenq_ngroups_swapped);
+    if (paged_KV) {
+        params.block_table = block_table.data_ptr<int>();
+        params.block_table_batch_stride = block_table.stride(0);
+        params.k_batch_stride = k_padded.stride(0);
+        params.v_batch_stride = v_padded.stride(0);
+    }
+    params.page_block_size = page_block_size;
+
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
         set_params_splitkv(params, batch_size, num_heads,
@@ -682,7 +712,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
-        run_mha_fwd(params, stream);
+        run_mha_fwd(params, stream, paged_KV);
     } else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
         out.zero_();
......
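
As a reference for the shape comments and CHECK_SHAPE branches above: with a block_table, key/value element j of batch entry b lives in physical block block_table[b, j // page_block_size] at offset j % page_block_size. The sketch below (gather_paged_kv is an illustrative helper, not part of the library) states that addressing in plain PyTorch, mirroring the gather done in the test helper further down; it is not the CUDA kernel's implementation.

import torch

def gather_paged_kv(cache_paged, block_table, seqlen_k):
    """Rebuild a contiguous (batch, seqlen_k, nheads_k, d) tensor from a paged cache.

    cache_paged: (num_blocks, page_block_size, nheads_k, d)
    block_table: (batch_size, max_num_blocks_per_seq), int32 physical block indices
    """
    batch_size = block_table.shape[0]
    out = []
    for b in range(batch_size):
        # Physical blocks of sequence b, concatenated in logical order, then truncated.
        blocks = cache_paged[block_table[b].long()]      # (nblocks, block_size, h, d)
        out.append(blocks.flatten(0, 1)[:seqlen_k])      # (seqlen_k, h, d)
    return torch.stack(out)                              # (batch, seqlen_k, h, d)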
@@ -79,6 +79,7 @@ def _flash_attn_varlen_forward(
     window_size,
     alibi_slopes,
     return_softmax,
+    block_table,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -90,6 +91,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        block_table,
         alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
@@ -299,6 +301,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -440,6 +443,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -570,6 +574,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        block_table,
     ):
         if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
@@ -587,6 +592,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
+            block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -630,7 +636,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -1001,6 +1007,7 @@ def flash_attn_varlen_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    block_table=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1071,6 +1078,7 @@ def flash_attn_varlen_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        block_table,
     )
......
@@ -1542,8 +1542,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
         (1023, 1024),
     ],
 )
+# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+def test_flash_attn_varlen_causal(
+    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
+):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1559,8 +1563,19 @@ def test_flash_attn_varlen_causal(
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if paged_kv_block_size is None:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        block_table = None
+    else:
+        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
+        )
     query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
     (
@@ -1580,8 +1595,8 @@ def test_flash_attn_varlen_causal(
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     out_unpad = flash_attn_varlen_func(
         q_unpad,
-        k_unpad,
-        v_unpad,
+        k_unpad if paged_kv_block_size is None else k_cache_paged,
+        v_unpad if paged_kv_block_size is None else v_cache_paged,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
@@ -1589,6 +1604,7 @@ def test_flash_attn_varlen_causal(
         0.0,
         causal=causal,
         window_size=window_size,
+        block_table=block_table,
     )
     out = output_pad_fn(out_unpad)
     out_ref, attn_ref = attention_ref(
@@ -1625,7 +1641,8 @@ def test_flash_attn_varlen_causal(
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    if test_backward:
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1678,7 @@ def test_flash_attn_varlen_causal(
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    if test_backward:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1888,29 +1905,16 @@ def test_flash_attn_kvcache(
         v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
         block_table = None
     else:
-        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
-        k_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        v_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        block_table = rearrange(
-            torch.randperm(num_blocks, dtype=torch.int32, device=device),
-            "(b nblocks) -> b nblocks",
-            b=batch_size,
-        )
-        k_cache = rearrange(
-            # pytorch 1.12 doesn't have indexing with int32
-            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
-        v_cache = rearrange(
-            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
+        (
+            k_cache,
+            v_cache,
+            block_table,
+            k_cache_paged,
+            v_cache_paged,
+            num_blocks,
+        ) = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
+        )
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -2073,6 +2077,33 @@ def test_flash_attn_kvcache(
     assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
 
 
+def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
+    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+    k_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    v_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    block_table = rearrange(
+        torch.randperm(num_blocks, dtype=torch.int32, device=device),
+        "(b nblocks) -> b nblocks",
+        b=batch_size,
+    )
+    k_cache = rearrange(
+        # pytorch 1.12 doesn't have indexing with int32
+        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    v_cache = rearrange(
+        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])
......