Unverified commit 5e4e0b2c, authored by Charlene Yang, committed by GitHub

[PyTorch] Add sink attention support from cuDNN (#2148)

* first draft; debug plan failure
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* debug uid error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak params
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add grad in output
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix prints in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* address review comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix unfused grad; add softmax_type; add sink to bwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix padding mask; add swa tests; remove requires_grad for off-by-one
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix indent
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix non-determinism and shapes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up prints
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add GQA
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add CP A2A; dq/dk mismatches
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; need cleaner solution
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CP A2A; pending cudnn kernel change
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix world size in unit test; avoid thd format
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix thd logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fp8 context
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak CP logging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* allow no_mask/padding for SWA(left,0)
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Revert "allow no_mask/padding for SWA(left,0)"

This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add softmax_type to Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add cuDNN version control
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* prettify tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* skip 9.13 for MLA, non 192/128
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename compare_with_error
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* small cleanups and improvements
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix minor CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force sink/dsink to be float32
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch FE to GH FE
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* return to GH TE main FE commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FE to 1.14.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up before CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* bump up cudnn version
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add backend selection guard for unit tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add docstring for softmax type enums in C
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 57b4d7bc
@@ -73,28 +73,31 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
     bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
-    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
-    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+    int64_t window_size_right);
 
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
-    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
+    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded,
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
-    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread);
 
 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
-    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
+    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
+    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
+    const at::ScalarType fake_dtype, const DType dqkv_type,
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...
@@ -58,13 +58,14 @@ namespace transformer_engine::pytorch {
 // get the fused attention backend
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
     bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
-    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
-    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+    int64_t window_size_right) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
-      bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
-      max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
+      max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 }
@@ -72,14 +73,15 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
     bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
-    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
-    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-    const std::optional<at::Tensor> cu_seqlens_q_padded,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
+    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
+    const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded,
     const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
     py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
-    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
+    const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
+    size_t rng_elts_per_thread) {
   TensorWrapper te_Q, te_K, te_V, te_O, te_S;
   auto none = py::none();
@@ -181,6 +183,16 @@ std::vector<py::object> fused_attn_fwd(
                                              DType::kInt32, nullptr, nullptr, nullptr);
   }
 
+  // softmax offset
+  TensorWrapper te_SoftmaxOffset;
+  if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
+    auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec();
+    std::vector<size_t> SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()};
+    te_SoftmaxOffset =
+        makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape,
+                                    DType::kFloat32, nullptr, nullptr, nullptr);
+  }
+
   // extract rng seed and offset
   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
       rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
@@ -199,11 +211,11 @@ std::vector<py::object> fused_attn_fwd(
   // populate tensors with appropriate shapes and dtypes
   NVTE_SCOPED_GIL_RELEASE({
     nvte_fused_attn_fwd(
-        te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
-        &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+        te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(),
+        te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
         te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
         te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
+        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
         window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
   });
@@ -215,51 +227,52 @@ std::vector<py::object> fused_attn_fwd(
   // output_tensors = [O, nvte_aux_tensor_pack.tensors]
   std::vector<py::object> output_tensors;
   output_tensors.push_back(o_python);
-  for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
-    // allocate memory for nvte_aux_tensor_pack.tensors
-    at::Tensor output_tensor;
-    if (nvte_aux_tensor_pack.size >= 2) {
-      if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
-        if (i < nvte_aux_tensor_pack.size - 2) {
-          NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
-          output_tensor = allocateSpace(
-              nvte_shape_to_vector(temp_shape),
-              static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
-        } else if (i == nvte_aux_tensor_pack.size - 2) {
-          output_tensor = rng_state;
-        } else if (i == nvte_aux_tensor_pack.size - 1) {
-          output_tensor = Bias.value();
-        }
-      } else {
-        NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
-        output_tensor =
-            (i < nvte_aux_tensor_pack.size - 1)
-                ? allocateSpace(
-                      nvte_shape_to_vector(temp_shape),
-                      static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false)
-                : rng_state;
-      }
-    } else {
-      NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
-      output_tensor = allocateSpace(
-          nvte_shape_to_vector(temp_shape),
-          static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
-    }
-    output_tensors.push_back(py::cast(output_tensor));
-    NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
-                                 nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
-                                 nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
-    nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
-  }
+  auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) {
+    output_tensors.push_back(py::cast(output_tensor));
+    NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
+                                 nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
+                                 nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
+    nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
+  };
+  // allocate memory for nvte_aux_tensor_pack.tensors
+  // f16_max512   : S [b, h, sq, skv]
+  // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+  // fp8          : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
+  size_t i = 0;
+  at::Tensor output_tensor;
+  // intermediate softmax tensor, S or M
+  output_tensor =
+      allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
+                    static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
+  set_tensor_param(i++, output_tensor);
+  // fp8 has an additional softmax stats tensor, ZInv
+  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
+    output_tensor =
+        allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
+                      static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
+    set_tensor_param(i++, output_tensor);
+  }
+  // rng_state
+  if (i < nvte_aux_tensor_pack.size) {
+    set_tensor_param(i++, rng_state);
+  }
+  // bias (optional)
+  if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
+    set_tensor_param(i++, Bias.value());
+  }
+  // softmax_offset (optional)
+  if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) {
+    set_tensor_param(i++, SoftmaxOffset.value());
+  }
 
   // execute the kernel
   NVTE_SCOPED_GIL_RELEASE({
     nvte_fused_attn_fwd(
-        te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
-        &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+        te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(),
+        te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
         te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
         te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
-        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0],
+        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
         window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
   });
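
For orientation, the rewritten allocation logic above fills nvte_aux_tensor_pack strictly in the order named in its comment: softmax stats first, then rng_state, then the optional Bias and SoftmaxOffset. A minimal Python sketch of that ordering (the expected_aux_tensors helper and the backend names are illustrative, not part of this PR):

def expected_aux_tensors(backend: str, has_bias: bool, softmax_type: str) -> list:
    # Mirrors the aux-tensor layout comment in fused_attn_fwd above (illustrative only).
    if backend == "f16_max512":
        return ["S"]                        # S: [b, h, sq, skv]
    if backend == "fp8":
        return ["M", "ZInv", "rng_state"]   # M, ZInv: [b, h, sq, 1]; rng_state: [2]
    aux = ["S", "rng_state"]                # f16_arbitrary; S: [b, h, sq, 1]
    if has_bias:
        aux.append("Bias")                  # [1, h, sq, skv]
    if softmax_type != "vanilla":
        aux.append("SoftmaxOffset")         # [1, h, 1, 1]
    return aux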
@@ -274,9 +287,10 @@ std::vector<py::object> fused_attn_fwd(
 std::vector<py::object> fused_attn_bwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
-    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
-    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
+    NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
+    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
+    const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
+    const at::ScalarType fake_dtype, const DType dqkv_type,
     const std::vector<at::Tensor> Aux_CTX_Tensors,
     const std::optional<at::Tensor> cu_seqlens_q_padded,
     const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
@@ -499,6 +513,15 @@ std::vector<py::object> fused_attn_bwd(
     }
   }
 
+  // create dSoftmaxOffset in the same shape as SoftmaxOffset
+  at::Tensor dSoftmaxOffset;
+  TensorWrapper te_dSoftmaxOffset;
+  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+    options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA);
+    dSoftmaxOffset = torch::empty({1, static_cast<int64_t>(h_q), 1, 1}, options);
+    te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset);
+  }
+
   // create workspace
   TensorWrapper workspace;
@@ -507,10 +530,10 @@ std::vector<py::object> fused_attn_bwd(
     nvte_fused_attn_bwd(
         te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
         &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
-        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
-        te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
-        qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic,
-        workspace.data(), at::cuda::getCurrentCUDAStream());
+        te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
+        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
+        window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
   });
 
   // allocate memory for workspace
@@ -523,16 +546,16 @@ std::vector<py::object> fused_attn_bwd(
     nvte_fused_attn_bwd(
         te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
         &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
-        te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
-        te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
-        qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic,
-        workspace.data(), at::cuda::getCurrentCUDAStream());
+        te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
+        te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
+        attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
+        window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream());
   });
 
   // destroy tensor wrappers
   nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
 
-  return {py_dQ, py_dK, py_dV, py::cast(dBias)};
+  return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)};
 }
 
 at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
...
@@ -966,6 +966,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             return
 
         dtype = inp.dtype
-        for name, param in self.named_parameters():
-            if param is not None:
-                assert dtype == param.dtype, (
+        if not self.allow_different_data_and_param_types:
+            for name, param in self.named_parameters():
+                if param is not None:
+                    assert dtype == param.dtype, (
@@ -1060,6 +1061,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
     ) -> Generator[torch.Tensor, None, None]:
         """Checks and prep for FWD.
         The context manager is needed because there isn't a way for a module to know
@@ -1067,6 +1069,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         to setup the forward aggregated amax reduction for every module
         just in case. The autocast exit will pick up the most recent one.
         """
+        self.allow_different_data_and_param_types = allow_different_data_and_param_types
         self.forwarded_at_least_once = True
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
...
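
The new allow_different_data_and_param_types flag threads through the prepare_forward context manager shown above: it is stored on the module and, when set, skips the input-vs-parameter dtype assertion. An illustrative call site, sketched under assumptions (some_gemm and the surrounding module body are hypothetical, not from this PR):

with self.prepare_forward(
    inp, num_gemms=1, allow_different_data_and_param_types=True
) as prepped_inp:
    # the dtype check against parameters is skipped for this forward pass
    out = some_gemm(prepped_inp, self.weight)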
@@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module):
                      and `DotProductAttention` modules.
     name: str, default = `None`
         name of the module, currently used for debugging purposes.
+    softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
+        softmax type as described in the paper
+        `Efficient Streaming Language Models with Attention Sinks
+        <https://arxiv.org/pdf/2309.17453v3>`_.
+        For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
+        'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+        'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
+        'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+        where alpha is a learnable parameter of shape [h].
+        'off-by-one' and 'learnable' softmax types are also called sink attention
+        ('zero sink' and 'learnable sink').
 
     Parallelism parameters
     ----------------------
@@ -306,6 +317,7 @@ class TransformerLayer(torch.nn.Module):
         qk_norm_type: Optional[str] = None,
         qk_norm_eps: float = 1e-6,
         qk_norm_before_rope: bool = False,
+        softmax_type: str = "vanilla",
     ) -> None:
         super().__init__()
@@ -362,6 +374,7 @@ class TransformerLayer(torch.nn.Module):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.attn_input_format = attn_input_format
+        self.softmax_type = softmax_type
         self.name = name
@@ -397,6 +410,7 @@ class TransformerLayer(torch.nn.Module):
             "qkv_format": self.attn_input_format,
             "seq_length": seq_length,
             "micro_batch_size": micro_batch_size,
+            "softmax_type": self.softmax_type,
         }
         self.self_attention = MultiheadAttention(
...
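
As a reference for the softmax_type docstring above, the three variants can be written out directly in PyTorch. A minimal sketch under the shapes given there (sink_softmax and its argument names are illustrative, not part of this PR; a production kernel would also subtract the row max before exponentiating, for numerical stability):

import torch
from typing import Optional

def sink_softmax(s: torch.Tensor, softmax_type: str = "vanilla",
                 alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
    # s: attention scores of shape [b, h, s_q, s_kv]
    num = torch.exp(s)
    denom = num.sum(dim=-1, keepdim=True)
    if softmax_type == "off-by-one":
        denom = denom + 1.0  # 'zero sink': one extra implicit logit fixed at 0
    elif softmax_type == "learnable":
        assert alpha is not None, "learnable softmax needs per-head alpha of shape [h]"
        denom = denom + torch.exp(alpha).view(1, -1, 1, 1)  # 'learnable sink'
    return num / denom

Because the sink term only enlarges the denominator, each output row sums to at most 1, which lets the model park attention mass on the sink instead of forcing it onto real tokens.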