"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "89eb58bc6fdb58df1506f8c118b1923169f8726e"
Commit 8ea2e1c9 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Fix the IsSupported check in contraction op

parent 8ff845f2
......@@ -11,6 +11,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
......@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_mz_stride_{},
a_kz_stride_{},
b_nz_stride_{},
b_kz_stride_{},
ds_nz_stride_{},
e_nz_stride_{}
cde_element_op_{cde_element_op}
{
// populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
......@@ -448,18 +443,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// for sanity check of vector memory access
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
}
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
a_max_read_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ns_lengths, a_ms_ks_strides);
b_max_read_elems_ =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i)
{
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimK>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
}
e_max_write_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(e_ms_ns_lengths, e_ms_ns_strides);
}
void Print() const
......@@ -499,15 +503,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store
index_t a_mz_stride_;
index_t a_kz_stride_;
index_t b_nz_stride_;
index_t b_kz_stride_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_mz_stride_;
index_t e_nz_stride_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
bool a_mz_consecutive_;
bool a_kz_consecutive_;
bool b_nz_consecutive_;
bool b_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
std::array<index_t, NumDTensor> ds_max_read_elems_;
index_t e_max_write_elems_;
};
// Invoker
......@@ -616,65 +624,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!");
// vector memory access of A: could be on M or AK1 dimension
if constexpr(ABlockTransferSrcVectorDim == 1)
const bool valid_a_vector_size =
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!(valid_a_vector_size && valid_a_access_dim))
{
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
return false;
}
// vector memory access of B: could be on N or BK1 dimension
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
const bool valid_b_vector_size =
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!(valid_b_vector_size && valid_b_access_dim))
{
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
return false;
}
// vector memory access of Ds: always on NPerBlock dimension
bool valid_d_access = true;
bool valid_ds_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
{
valid_d_access = false;
valid_ds_access = false;
}
});
if(valid_d_access == false)
if(valid_ds_access == false)
{
return false;
}
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
if(!(valid_e_vector_size && valid_e_access_dim))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* Calculates the maximum number of subsequent elements of the fast changing dimension
* that are consecutive in memory.
*
* Example:
* NumDimM = 2, NumDimK = 3
* A shape = [ 2, 3, 4, 5, 6]
* A strides = [360, 120, 30, 6, 1]
* | M | | K |
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
* in memory.
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
* consecutive in memory.
*
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
*/
template <index_t NumDim1, index_t NumDim2>
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
assert(lengths.size() == NumDim1 + NumDim2 && strides.size() == NumDim1 + NumDim2);
// Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx;
if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return 1;
}
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
if(strides[dim_idx] == consecutive_stride)
{
consecutive_stride *= lengths[dim_idx];
}
else
{
break;
}
}
const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment