Unverified Commit c036765b authored by Michael Goldfarb, committed by GitHub

[JAX] Consolidate FFI and old descriptor implementation for fused attention. (#1295)



Consolidate FFI and old descriptor implementation for fused attention.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent ed1e85c4
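The idea behind this change is to keep a single implementation function that takes every buffer and attribute as an explicit argument, and to have both entry points (the legacy opaque-descriptor custom call and the newer XLA FFI handler) do nothing but unpack their inputs and forward them to it. Below is a minimal, self-contained C++ sketch of that pattern; all names (AttnDescriptor, FusedAttnImpl, FusedAttnCustomCall, FusedAttnFFI) are hypothetical stand-ins, not the actual Transformer Engine code in this diff.

// Minimal sketch of the consolidation pattern: one shared Impl function with explicit
// parameters, plus two thin entry points that only unpack their inputs.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>

// Hypothetical stand-in for the packed descriptor passed through the opaque string.
struct AttnDescriptor {
  size_t batch;
  float scaling_factor;
};

// Shared implementation: independent of how the caller received its parameters
// (opaque descriptor vs. FFI attributes).
static void FusedAttnImpl(void *q, void *k, void *v, void *out, size_t batch,
                          float scaling_factor) {
  std::cout << "running fused attention: batch=" << batch
            << " scale=" << scaling_factor << "\n";
  (void)q; (void)k; (void)v; (void)out;
}

// Legacy-style entry point: buffers arrive as a void** array and parameters as an
// opaque byte blob that is reinterpreted as the descriptor.
void FusedAttnCustomCall(void **buffers, const char *opaque, size_t opaque_len) {
  AttnDescriptor desc{};
  std::memcpy(&desc, opaque, opaque_len < sizeof(desc) ? opaque_len : sizeof(desc));
  FusedAttnImpl(buffers[0], buffers[1], buffers[2], buffers[3], desc.batch,
                desc.scaling_factor);
}

// FFI-style entry point: buffers and attributes arrive as typed arguments and are
// cast to the types the shared implementation expects.
void FusedAttnFFI(void *q, void *k, void *v, void *out, int64_t batch_,
                  double scaling_factor_) {
  FusedAttnImpl(q, k, v, out, static_cast<size_t>(batch_),
                static_cast<float>(scaling_factor_));
}

int main() {
  float q[4]{}, k[4]{}, v[4]{}, out[4]{};
  AttnDescriptor desc{2, 0.125f};
  void *buffers[] = {q, k, v, out};
  FusedAttnCustomCall(buffers, reinterpret_cast<const char *>(&desc), sizeof(desc));
  FusedAttnFFI(q, k, v, out, 2, 0.125);
  return 0;
}

With this split, any future change to the fused-attention call sequence only has to be made once in the shared implementation, while the two wrappers stay limited to argument unpacking.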
@@ -185,46 +185,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
 }
 
-void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
-  const CustomCallFusedAttnDescriptor &descriptor =
-      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
-  auto qkv_layout = descriptor.qkv_layout;
+static void FusedAttnForwardImpl(
+    cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
+    void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
+    void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
+    size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
+    size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size,
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
+    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
 
-  /* Input buffers from XLA */
-  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
-  void *bias = buffers[3];
-  void *q_cu_seqlens = buffers[4];
-  void *kv_cu_seqlens = buffers[5];
-  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
-  void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
-  void *seed = buffers[8];
-
-  /* Output buffer from XLA */
-  void *output = buffers[9];
-  void *softmax_aux = buffers[10];
-  void *rng_state = buffers[11];
-  void *workspace = buffers[12];
-
-  /* Descriptor */
-  auto input_batch = descriptor.input_batch;
-  auto bias_batch = descriptor.bias_batch;
-  auto q_max_seqlen = descriptor.q_max_seqlen;
-  auto kv_max_seqlen = descriptor.kv_max_seqlen;
-  auto attn_heads = descriptor.attn_heads;
-  auto num_gqa_groups = descriptor.num_gqa_groups;
-  auto bias_heads = descriptor.bias_heads;
-  auto head_dim = descriptor.head_dim;
-  auto scaling_factor = descriptor.scaling_factor;
-  auto dropout_probability = descriptor.dropout_probability;
-  auto bias_type = descriptor.bias_type;
-  auto mask_type = descriptor.mask_type;
-  auto dtype = descriptor.dtype;
-  auto is_training = descriptor.is_training;
-  auto max_segments_per_seq = descriptor.max_segments_per_seq;
-  auto window_size_left = descriptor.window_size_left;
-  auto window_size_right = descriptor.window_size_right;
-
   /* Input tensors */
   auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
   auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
@@ -247,8 +218,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
       NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
       num_segments = runtime_num_segments_q;
     }
-    cudaMemsetAsync(output, 0,
-                    input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
+    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
+    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
   }
 
   auto q_cu_seqlens_tensor =
@@ -281,28 +252,25 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
                                     backend, softmax_aux);
 
   /* cuDNN workspace */
-  auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
-                                        descriptor.wkspace_dtype);
+  auto workspace_tensor =
+      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
 
-  /* Call the underly NVTE API */
+  /* Call the underlying NVTE API */
   auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv = buffers[0];
     auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
-    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
-    nvte_fused_attn_fwd_qkvpacked(
-        qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
-        dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
-        workspace_tensor.data(), stream);
+    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
+    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
+                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
+                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
+                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
+                                  qkv_layout, bias_type, mask_type, window_size_left,
+                                  window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto kv = buffers[1];
     auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
-    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
+    auto q_tensor = TensorWrapper(q, q_shape, dtype);
+    auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
@@ -310,14 +278,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
         q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
         bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto k = buffers[1];
     auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-    auto k_tensor = TensorWrapper(k, k_shape, dtype);
-    auto v = buffers[2];
     auto v_shape = k_shape;
+    auto q_tensor = TensorWrapper(q, q_shape, dtype);
+    auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
     nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
@@ -333,6 +298,37 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
   nvte_tensor_pack_destroy(&aux_output_tensors);
 }
 
+void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
+  const CustomCallFusedAttnDescriptor &descriptor =
+      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
+  auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD;
+
+  /* Input buffers from XLA */
+  /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
+  void *bias = buffers[3];
+  void *q_cu_seqlens = buffers[4];
+  void *kv_cu_seqlens = buffers[5];
+  void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
+  void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
+  void *seed = buffers[8];
+
+  /* Output buffer from XLA */
+  void *output = buffers[9];
+  void *softmax_aux = buffers[10];
+  void *rng_state = buffers[11];
+  void *workspace = buffers[12];
+
+  FusedAttnForwardImpl(
+      stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets,
+      k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch,
+      descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen,
+      descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
+      descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
+      descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
+      descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
+      descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
+}
+
 Error_Type FusedAttnForwardFFI(
     cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Buffer_Type v_buf,
     Buffer_Type bias_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
@@ -344,147 +340,25 @@ Error_Type FusedAttnForwardFFI(
     double dropout_probability_, int64_t bias_type_, int64_t mask_type_, int64_t qkv_layout_,
     int64_t dtype_, int64_t wkspace_dtype_, bool is_training, bool deterministic,
     int64_t window_size_left, int64_t window_size_right) {
-  /* Descriptor data type conversion */
-  size_t input_batch = static_cast<size_t>(input_batch_);
-  size_t bias_batch = static_cast<size_t>(bias_batch_);
-  size_t q_max_seqlen = static_cast<size_t>(q_max_seqlen_);
-  size_t kv_max_seqlen = static_cast<size_t>(kv_max_seqlen_);
-  size_t attn_heads = static_cast<size_t>(attn_heads_);
-  size_t num_gqa_groups = static_cast<size_t>(num_gqa_groups_);
-  size_t bias_heads = static_cast<size_t>(bias_heads_);
-  size_t head_dim = static_cast<size_t>(head_dim_);
-  size_t max_segments_per_seq = static_cast<size_t>(max_segments_per_seq_);
-  size_t wkspace_size = static_cast<size_t>(wkspace_size_);
-  float scaling_factor = static_cast<float>(scaling_factor_);
-  float dropout_probability = static_cast<float>(dropout_probability_);
-  NVTE_Bias_Type bias_type = static_cast<NVTE_Bias_Type>(bias_type_);
-  NVTE_Mask_Type mask_type = static_cast<NVTE_Mask_Type>(mask_type_);
   NVTE_QKV_Layout qkv_layout = static_cast<NVTE_QKV_Layout>(qkv_layout_);
-  DType dtype = static_cast<DType>(dtype_);
-  DType wkspace_dtype = static_cast<DType>(wkspace_dtype_);
-
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
 
-  /* Input buffers from XLA */
-  /* q, k, v are parsed later for different qkv_layout */
-  void *bias = bias_buf.untyped_data();
-  void *q_cu_seqlens = q_cu_seqlens_buf.untyped_data();
-  void *kv_cu_seqlens = kv_cu_seqlens_buf.untyped_data();
-  void *q_seq_offsets = is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr;
-  void *k_seq_offsets = is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr;
-  void *seed = seed_buf.untyped_data();
-
-  /* Output buffer from XLA */
-  void *output = output_buf->untyped_data();
-  void *softmax_aux = softmax_aux_buf->untyped_data();
-  void *rng_state = rng_state_buf->untyped_data();
-  void *workspace = workspace_buf->untyped_data();
-
-  /* Input tensors */
-  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-  auto v_shape = k_shape;
-  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
-  auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
-
-  size_t num_segments = input_batch;  // Non-THD format, input_batch = num_segments
-  if (is_ragged) {
-    auto cudnn_runtime_version = cudnnGetVersion();
-    if (cudnn_runtime_version >= 90300) {
-      num_segments = input_batch * max_segments_per_seq;
-    } else {
-      // workspace can be reused here as it is not used with cuDNN graph at the same time
-      size_t runtime_num_segments_q =
-          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
-      size_t runtime_num_segments_kv =
-          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
-      NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
-      NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
-      num_segments = runtime_num_segments_q;
-    }
-    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
-    cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
-  }
-
-  auto q_cu_seqlens_tensor =
-      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-  auto kv_cu_seqlens_tensor =
-      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-  auto q_seq_offsets_tensor =
-      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-  auto k_seq_offsets_tensor =
-      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-
-  /* Output tensors */
-  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
-  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-  auto o_tensor = TensorWrapper(output, o_shape, dtype);
-
-  /* Prepare RNG state */
-  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
-  auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
-  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
-
-  /* Auxiliary tensors (to be propagated to the backward pass later) */
-  NVTETensorPack aux_output_tensors;
-  nvte_tensor_pack_create(&aux_output_tensors);
-  PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads,
-                                    bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
-                                    backend, softmax_aux);
-
-  /* cuDNN workspace */
-  auto workspace_tensor =
-      TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
-
-  /* Call the underlying NVTE API */
-  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
-  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv = q_buf.untyped_data();
-    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
-    auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
-    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
-                                  o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
-                                  q_seq_offsets_tensor.data(), rng_state_tensor.data(),
-                                  q_max_seqlen, is_training, scaling_factor, dropout_probability,
-                                  qkv_layout, bias_type, mask_type, window_size_left,
-                                  window_size_right, workspace_tensor.data(), stream);
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q = q_buf.untyped_data();
-    auto kv = k_buf.untyped_data();
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
-    nvte_fused_attn_fwd_kvpacked(
-        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
-  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q = q_buf.untyped_data();
-    auto k = k_buf.untyped_data();
-    auto v = v_buf.untyped_data();
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-    auto v_shape = k_shape;
-    auto q_tensor = TensorWrapper(q, q_shape, dtype);
-    auto k_tensor = TensorWrapper(k, k_shape, dtype);
-    auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
-                        s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
-                        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-                        window_size_left, window_size_right, workspace_tensor.data(), stream);
-  } else {
-    NVTE_ERROR("Unsupported qkv_layout.");
-  }
-
-  nvte_tensor_pack_destroy(&aux_output_tensors);
+  FusedAttnForwardImpl(
+      stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
+      bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
+      is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
+      is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
+      output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
+      workspace_buf->untyped_data(), static_cast<size_t>(input_batch_),
+      static_cast<size_t>(bias_batch_), static_cast<size_t>(q_max_seqlen_),
+      static_cast<size_t>(kv_max_seqlen_), static_cast<size_t>(attn_heads_),
+      static_cast<size_t>(num_gqa_groups_), static_cast<size_t>(bias_heads_),
+      static_cast<size_t>(head_dim_), static_cast<size_t>(max_segments_per_seq_),
+      static_cast<size_t>(wkspace_size_), static_cast<float>(scaling_factor_),
+      static_cast<float>(dropout_probability_), static_cast<NVTE_Bias_Type>(bias_type_),
+      static_cast<NVTE_Mask_Type>(mask_type_), static_cast<NVTE_QKV_Layout>(qkv_layout_),
+      static_cast<DType>(dtype_), static_cast<DType>(wkspace_dtype_), is_training, deterministic,
+      window_size_left, window_size_right);
 
   return ffi_with_cuda_error_check();
 }