Unverified Commit 933951ed authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix continous dim selection in contraction (#1336)

* Fix continous dim selection in contraction

* Fixes
parent 17ed368f
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access // for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i) for(index_t i = 0; i < NumATensor; ++i)
{ {
as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1; tie(as_continous_dim_[i], as_max_read_elems_[i]) =
as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
as_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]); CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
} }
for(index_t i = 0; i < NumBTensor; ++i) for(index_t i = 0; i < NumBTensor; ++i)
{ {
bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1; tie(bs_continous_dim_[i], bs_max_read_elems_[i]) =
bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
bs_max_read_elems_[i] =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]); CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
} }
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]); CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
} }
e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1; tie(e_continous_dim_, e_max_write_elems_) =
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride); CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
} }
// pointers // pointers
...@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive // Describe whether the last part of a given dimension of A/B/D/E is continues dim.
// in the memory or not. std::array<index_t, NumATensor> as_continous_dim_;
std::array<bool, NumATensor> as_mz_consecutive_; std::array<index_t, NumATensor> bs_continous_dim_;
std::array<bool, NumATensor> as_kz_consecutive_; std::array<index_t, NumBTensor> ds_continous_dim_;
std::array<bool, NumBTensor> bs_nz_consecutive_; index_t e_continous_dim_;
std::array<bool, NumBTensor> bs_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
std::array<index_t, NumATensor> as_max_read_elems_; std::array<index_t, NumATensor> as_max_read_elems_;
std::array<index_t, NumBTensor> bs_max_read_elems_; std::array<index_t, NumBTensor> bs_max_read_elems_;
...@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_a_vector_size = const bool valid_a_vector_size =
arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = const bool valid_a_access_dim_m =
ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i]; ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
const bool valid_a_access_dim_k = const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!((valid_a_vector_size && valid_a_access_dim) || if(!((valid_a_vector_size && valid_a_access_dim) ||
ABlockTransferSrcScalarPerVector == 1)) ABlockTransferSrcScalarPerVector == 1))
...@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_b_vector_size = const bool valid_b_vector_size =
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i]; BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
const bool valid_b_access_dim_k = const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!((valid_b_vector_size && valid_b_access_dim) || if(!((valid_b_vector_size && valid_b_access_dim) ||
BBlockTransferSrcScalarPerVector == 1)) BBlockTransferSrcScalarPerVector == 1))
...@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_d_vector_size = const bool valid_d_vector_size =
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension. // Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
if(!((valid_d_vector_size && valid_d_access_dim) || if(!((valid_d_vector_size && valid_d_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1)) CDEBlockTransferScalarPerVector_NPerBlock == 1))
{ {
...@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_e_vector_size = const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension. // Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_; const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
if(!((valid_e_vector_size && valid_e_access_dim) || if(!((valid_e_vector_size && valid_e_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1)) CDEBlockTransferScalarPerVector_NPerBlock == 1))
{ {
......
...@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} }
// for sanity check of vector memory access // for sanity check of vector memory access
a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1; tie(a_continous_dim_, a_max_read_elems_) =
a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
a_max_read_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides); CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1; tie(b_continous_dim_, b_max_read_elems_) =
b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
b_max_read_elems_ =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides); CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; tie(ds_continous_dim_[i], ds_max_read_elems_[i]) =
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
} }
e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1; tie(e_continous_dim_, e_max_write_elems_) =
e_max_write_elems_ =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides); CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
} }
...@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive // Describe whether the last part of a given dimension of A/B/D/E is continues dim.
// in the memory or not. index_t a_continous_dim_;
bool a_mz_consecutive_; index_t b_continous_dim_;
bool a_kz_consecutive_; std::array<index_t, NumDTensor> ds_continous_dim_;
bool b_nz_consecutive_; index_t e_continous_dim_;
bool b_kz_consecutive_;
std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
index_t a_max_read_elems_; index_t a_max_read_elems_;
index_t b_max_read_elems_; index_t b_max_read_elems_;
...@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -624,8 +615,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_a_vector_size = const bool valid_a_vector_size =
arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; const bool valid_a_access_dim_m =
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0;
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1;
const bool valid_a_access_dim = const bool valid_a_access_dim =
valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
if(!(valid_a_vector_size && valid_a_access_dim)) if(!(valid_a_vector_size && valid_a_access_dim))
...@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -635,8 +628,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const bool valid_b_vector_size = const bool valid_b_vector_size =
arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; const bool valid_b_access_dim_n =
const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0;
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1;
const bool valid_b_access_dim = const bool valid_b_access_dim =
valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
if(!(valid_b_vector_size && valid_b_access_dim)) if(!(valid_b_vector_size && valid_b_access_dim))
...@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -650,7 +645,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension. // Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = const bool valid_d_access_dim =
arg.ds_nz_consecutive_[i] || CDEBlockTransferScalarPerVector_NPerBlock == 1; arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_d_vector_size && valid_d_access_dim)) if(!(valid_d_vector_size && valid_d_access_dim))
{ {
valid_ds_access = false; valid_ds_access = false;
...@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -665,7 +660,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension. // Vector write of E is always on N dimension.
const bool valid_e_access_dim = const bool valid_e_access_dim =
arg.e_nz_consecutive_ || CDEBlockTransferScalarPerVector_NPerBlock == 1; arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
if(!(valid_e_vector_size && valid_e_access_dim)) if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind ...@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
} }
// Determine the beginning and end idx of the group representing the FCD. // Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx; index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
if(strides[NumDim1 - 1] == 1) if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
{ {
begin_idx = 0; // MZ or KZ are ones
end_idx = NumDim1 - 1; bool dims1_are_ones = true;
for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
{
if(lengths[dim_idx] != 1)
{
dims1_are_ones = false;
}
}
if(dims1_are_ones)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
continous_dim = 1;
}
else
{
begin_idx = 0;
end_idx = NumDim1 - 1;
continous_dim = 0;
}
}
else if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
continous_dim = 0;
} }
else if(strides[NumDim1 + NumDim2 - 1] == 1) else if(strides[NumDim1 + NumDim2 - 1] == 1)
{ {
begin_idx = NumDim1; begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1; end_idx = NumDim1 + NumDim2 - 1;
continous_dim = 1;
} }
else else
{ {
// The dimension consecutive in memory is not the last dimension of any group, so only // The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once. // one element can be read/written at once.
return 1; consecutive_stride = 1;
continous_dim = 0;
return make_tuple(continous_dim, consecutive_stride);
} }
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx) for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{ {
if(strides[dim_idx] == consecutive_stride) if(strides[dim_idx] == consecutive_stride)
...@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind ...@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
} }
} }
const index_t max_subsequent_elems = consecutive_stride; const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems; return make_tuple(continous_dim, max_subsequent_elems);
} }
} // namespace device } // namespace device
......
...@@ -212,4 +212,10 @@ TYPED_TEST(TestContractionScaleMixedPrecision, scale) ...@@ -212,4 +212,10 @@ TYPED_TEST(TestContractionScaleMixedPrecision, scale)
this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}});
this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); this->template Run<2>({{16, 8}, {16, 8}, {16, 8}});
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); this->template Run<2>({{8, 16}, {16, 8}, {8, 16}});
// special cases
this->template Run<2>({{1, 1}, {16, 8}, {8, 16}});
this->template Run<2>({{8, 16}, {16, 8}, {1, 1}});
this->template Run<2>({{8, 16}, {1, 1}, {8, 16}});
this->template Run<2>({{1, 1}, {1, 1}, {1, 1}});
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment