Commit b2df7018 authored by letaoqin's avatar letaoqin
Browse files

fix ComputeBasePtrOfStridedBatch init bug

parent 4b0a5069
......@@ -183,8 +183,8 @@ int run(int argc, char* argv[])
// TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
......@@ -213,8 +213,8 @@ int run(int argc, char* argv[])
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number of
// elements on a thread
{seed, offset}); // dropout random seed and offset, offset should be at
// least the number of elements on a thread
if(!gemm.IsSupportedArgument(argument))
{
......@@ -230,7 +230,9 @@ int run(int argc, char* argv[])
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * std::is_void<DDataType>::value?1:0) *
sizeof(DDataType) * M * N * std::is_void<DDataType>::value
? 0
: 1) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......@@ -243,8 +245,8 @@ int run(int argc, char* argv[])
if(do_verification)
{
// run for storing z tensor
argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
argument =
gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
......@@ -273,8 +275,8 @@ int run(int argc, char* argv[])
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
......
......@@ -697,6 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment