Commit 1696ca42 authored by letaoqin's avatar letaoqin
Browse files

batched bwd add check for bias

parent 21cec2bb
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 32 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......
......@@ -861,6 +861,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_strides[NumDimG + NumDimM]);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
......@@ -974,6 +977,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
index_t m_raw_padded_;
index_t n_raw_padded_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
};
// Invoker
......@@ -1120,6 +1126,19 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false;
}
if constexpr(!is_same<D0DataType, void>::value)
{
if(arg.d0_n_length_stride_[1] == 1 &&
arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
......
......@@ -874,6 +874,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_strides[NumDimG + NumDimM]);
}
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
......@@ -987,6 +990,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
index_t m_raw_padded_;
index_t n_raw_padded_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
};
// Invoker
......@@ -1148,6 +1154,18 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false;
}
if constexpr(!is_same<D0DataType, void>::value)
{
if(arg.d0_n_length_stride_[1] == 1 &&
arg.d0_n_length_stride_[0] % D0BlockTransferSrcScalarPerVector != 0)
{
return false;
}
if(arg.d0_n_length_stride_[1] != 1 && D0BlockTransferSrcScalarPerVector != 1)
{
return false;
}
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment