Commit dbfe0051 authored by Bartlomiej Kocot
Browse files

Minor fixes

Minor fixes

Minor fixes
parent 31257062
...@@ -226,14 +226,17 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -226,14 +226,17 @@ bool run_grouped_conv_fwd(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv =
InDataType, ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
WeiDataType, InDataType,
OutDataType, WeiDataType,
InElementOp, OutDataType,
WeiElementOp, InElementOp,
OutElementOp, WeiElementOp,
NumDs>(); OutElementOp,
0, /*Num A Elementwise Tensors*/
0, /*Num B Elementwise Tensors*/
NumDs>();
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in, auto ref_argument = ref_conv.MakeArgument(in,
...@@ -246,6 +249,8 @@ bool run_grouped_conv_fwd(bool do_verification, ...@@ -246,6 +249,8 @@ bool run_grouped_conv_fwd(bool do_verification,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
{},
{},
d_tensors); d_tensors);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
......
...@@ -104,7 +104,7 @@ __global__ void ...@@ -104,7 +104,7 @@ __global__ void
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -121,13 +121,13 @@ __global__ void ...@@ -121,13 +121,13 @@ __global__ void
AsPointer p_as_grid_grp; AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp; BsPointer p_bs_grid_grp;
const auto as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}( static_for<0, NumATensor, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
const auto bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}( static_for<0, NumBTensor, 1>{}(
...@@ -988,7 +988,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle ...@@ -988,7 +988,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
static auto MakeArgument( static auto MakeArgument(
APointers p_as, APointers p_as,
BPointers p_bs, BPointers p_bs,
std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
......
...@@ -20,17 +20,12 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -20,17 +20,12 @@ struct ComputePtrOffsetOfStridedBatch
index_t BatchStrideB, index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs, Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE) index_t BatchStrideE)
: BatchStrideA_(), : BatchStrideA_(BatchStrideA),
BatchStrideB_(), BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE) BatchStrideE_(BatchStrideE)
{ {
if constexpr(!isMultiAB) if constexpr(isMultiAB)
{
BatchStrideA_ = BatchStrideA;
BatchStrideB_ = BatchStrideB;
}
else
{ {
static_assert("Invalid constructor for multiple A or B"); static_assert("Invalid constructor for multiple A or B");
} }
...@@ -40,17 +35,12 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -40,17 +35,12 @@ struct ComputePtrOffsetOfStridedBatch
Array<ck::index_t, NumBTensor> BatchStrideBs, Array<ck::index_t, NumBTensor> BatchStrideBs,
Array<ck::index_t, NumDTensor> BatchStrideDs, Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE) index_t BatchStrideE)
: BatchStrideA_(), : BatchStrideA_(BatchStrideAs),
BatchStrideB_(), BatchStrideB_(BatchStrideBs),
BatchStrideDs_(BatchStrideDs), BatchStrideDs_(BatchStrideDs),
BatchStrideE_(BatchStrideE) BatchStrideE_(BatchStrideE)
{ {
if constexpr(isMultiAB) if constexpr(!isMultiAB)
{
BatchStrideA_ = BatchStrideAs;
BatchStrideB_ = BatchStrideBs;
}
else
{ {
static_assert("Invalid constructor for single A and B"); static_assert("Invalid constructor for single A and B");
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment