Commit 8ea2e1c9 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Fix the IsSupported check in contraction op

parent 8ff845f2
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
...@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op}
a_mz_stride_{},
a_kz_stride_{},
b_nz_stride_{},
b_kz_stride_{},
ds_nz_stride_{},
e_nz_stride_{}
{ {
// populate pointer, batch stride, desc for Ds // populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -448,18 +443,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -448,18 +443,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} }
// for sanity check of vector memory access // for sanity check of vector memory access
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
} }
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; a_max_read_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ns_lengths, a_ms_ks_strides);
b_max_read_elems_ =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimK>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
}
e_max_write_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(e_ms_ns_lengths, e_ms_ns_strides);
} }
void Print() const void Print() const
...@@ -499,15 +503,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -499,15 +503,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E // Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store // in the memory or not.
index_t a_mz_stride_; bool a_mz_consecutive_;
index_t a_kz_stride_; bool a_kz_consecutive_;
index_t b_nz_stride_; bool b_nz_consecutive_;
index_t b_kz_stride_; bool b_kz_consecutive_;
std::array<index_t, NumDTensor> ds_nz_stride_; std::array<bool, NumDTensor> ds_nz_consecutive_;
index_t e_mz_stride_; bool e_nz_consecutive_;
index_t e_nz_stride_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
std::array<index_t, NumDTensor> ds_max_read_elems_;
index_t e_max_write_elems_;
}; };
// Invoker // Invoker
...@@ -616,65 +624,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -616,65 +624,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!"); "wrong!");
// vector memory access of A: could be on M or AK1 dimension const bool valid_a_vector_size =
if constexpr(ABlockTransferSrcVectorDim == 1) arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!(valid_a_vector_size && valid_a_access_dim))
{ {
if(!(arg.a_mz_stride_ == 1 && return false;
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
} }
// vector memory access of B: could be on N or BK1 dimension const bool valid_b_vector_size =
if constexpr(BBlockTransferSrcVectorDim == 1) arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
{ const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
if(!(arg.b_nz_stride_ == 1 && const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
{ if(!(valid_b_vector_size && valid_b_access_dim))
return false;
}
}
else
{ {
if(!(arg.b_kz_stride_ == 1 && return false;
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
} }
// vector memory access of Ds: always on NPerBlock dimension bool valid_ds_access = true;
bool valid_d_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 && const bool valid_d_vector_size =
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
CDEBlockTransferScalarPerVector_NPerBlock == // Vector read of Ds is always on N dimension.
0)) const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
{ {
valid_d_access = false; valid_ds_access = false;
} }
}); });
if(valid_ds_access == false)
if(valid_d_access == false)
{ {
return false; return false;
} }
// vector memory access of E: always on NPerBlock dimension const bool valid_e_vector_size =
if(!(arg.e_nz_stride_ == 1 && arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % // Vector write of E is always on N dimension.
CDEBlockTransferScalarPerVector_NPerBlock == const bool valid_e_access_dim = arg.e_nz_consecutive_;
0)) if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* Calculates the maximum number of subsequent elements of the fast changing dimension
* that are consecutive in memory.
*
* Example:
* NumDimM = 2, NumDimK = 3
* A shape = [ 2, 3, 4, 5, 6]
* A strides = [360, 120, 30, 6, 1]
* | M | | K |
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
* in memory.
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
* consecutive in memory.
*
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
*/
template <index_t NumDim1, index_t NumDim2>
// Explicit return type: with `auto`, the early `return 1;` deduces `int` while the final
// return deduces `index_t`; the two deductions conflict (a compile error) on any build
// where `index_t` is not exactly `int`.
index_t CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
    // The two groups together must describe every dimension of the tensor.
    assert(lengths.size() == NumDim1 + NumDim2 && strides.size() == NumDim1 + NumDim2);

    // Determine the beginning and end idx of the group representing the FCD, i.e. the group
    // whose innermost dimension has stride 1.
    index_t begin_idx, end_idx;
    if(strides[NumDim1 - 1] == 1)
    {
        begin_idx = 0;
        end_idx   = NumDim1 - 1;
    }
    else if(strides[NumDim1 + NumDim2 - 1] == 1)
    {
        begin_idx = NumDim1;
        end_idx   = NumDim1 + NumDim2 - 1;
    }
    else
    {
        // The dimension consecutive in memory is not the last dimension of any group, so only
        // one element can be read/written at once.
        return 1;
    }

    // Walk the FCD group from innermost to outermost, extending the consecutive run for as
    // long as each dimension's stride equals the combined extent of the dimensions inside it
    // (i.e. the layout of that sub-group is dense).
    index_t consecutive_stride = 1;
    for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
    {
        if(strides[dim_idx] == consecutive_stride)
        {
            consecutive_stride *= lengths[dim_idx];
        }
        else
        {
            break;
        }
    }

    // Maximum number of subsequent FCD elements that are consecutive in memory.
    return consecutive_stride;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment