Commit 81b26528 authored by Chao Liu

added bias add; worked around compiler issues

parent 4f2c8bce
@@ -50,7 +50,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename SrcElementwiseOperation,
+          typename DstElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
@@ -72,9 +72,9 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
         const DstDesc& dst_desc,
         const Index& dst_slice_origin_idx,
-        const SrcElementwiseOperation src_element_op)
+        const DstElementwiseOperation& dst_element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          src_element_op_{src_element_op}
+          dst_element_op_{dst_element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
@@ -201,7 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 // apply element-wise operation and type convert
                 dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
+                    type_convert<DstData>(dst_element_op_(src_buf[Number<src_offset>{}]));
             });

             const bool is_dst_valid =
@@ -378,7 +378,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;

-    SrcElementwiseOperation src_element_op_;
+    const DstElementwiseOperation dst_element_op_;
 }; // namespace ck

 // Assume:
...
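Note (editorial sketch, not part of this commit): in v1r3 the renamed DstElementwiseOperation is still a unary callable, applied per element right before type_convert<DstData> and the destination store (see the @@ -201 hunk above). A compatible functor, using a hypothetical Scale op not defined anywhere in this commit, could look like:

// hypothetical example only; invoked as dst_element_op_(src_buf[...]) just before the convert/store
struct Scale
{
    float alpha = 1.0f;

    __host__ __device__ constexpr float operator()(float v) const { return alpha * v; }
};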
@@ -32,7 +32,7 @@ template <typename SrcData,
           typename DstDesc,
           typename Dst0Desc, // this is really one of sources, but it has same shape as DstDesc
           typename Dst1Desc, // this is really one of sources, but it has same shape as DstDesc
-          typename SrcElementwiseOperation,
+          typename DstElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
@@ -60,11 +60,11 @@ struct ThreadwiseTensorSliceTransfer_v1r4
         const Dst0Desc& dst0_desc,
         const Dst1Desc& dst1_desc,
         const Index& dst_slice_origin_idx,
-        const SrcElementwiseOperation src_element_op)
+        const DstElementwiseOperation& dst_element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
           dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin_idx)),
           dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin_idx)),
-          src_element_op_{src_element_op}
+          dst_element_op_{dst_element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
@@ -258,15 +258,45 @@ struct ThreadwiseTensorSliceTransfer_v1r4
             using dst_vector_t =
                 typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

-            // copy data from src_buf into dst_vector
-            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-                constexpr index_t src_offset = src_desc.CalculateOffset(
-                    src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
-
-                // apply element-wise operation and type convert
-                dst_vector.template AsType<DstData>()(i) =
-                    type_convert<DstData>(src_element_op_(src_buf[Number<src_offset>{}]));
-            });
+            // load dst0 and dst1 and apply elementwise operation
+            {
+                // WARNING!!!!!!: this logic is only correct if DstScalarPerVector=1
+                // TODO: fix this
+                static_assert(DstScalarPerVector == 1, "wrong!");
+
+                // copy data from src_buf into dst_vector_src_data
+                constexpr index_t src_offset =
+                    src_desc.CalculateOffset(src_slice_origin_idx + dst_data_idx);
+
+                const SrcData src_v = src_buf[Number<src_offset>{}];
+
+                // load dst0 and dst1
+                const bool is_dst0_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst0_desc,
+                                                                                dst0_coord_);
+                const bool is_dst1_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst1_desc,
+                                                                                dst1_coord_);
+
+                const DstData dst0_v =
+                    dst0_buf.template Get<DstData>(dst0_coord_.GetOffset(), is_dst0_valid);
+                const DstData dst1_v =
+                    dst1_buf.template Get<DstData>(dst1_coord_.GetOffset(), is_dst1_valid);
+
+#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
+                // apply element-wise operation in SrcData type
+                const SrcData dst_v = dst_element_op_(
+                    src_v, type_convert<SrcData>(dst0_v), type_convert<SrcData>(dst1_v));
+
+                // apply type convert
+                dst_vector.template AsType<DstData>()(Number<0>{}) = type_convert<DstData>(dst_v);
+#else
+                // apply element-wise operation in DstData type
+                const DstData dst_v = dst_element_op_(src_v, dst0_v, dst1_v);
+
+                dst_vector.template AsType<DstData>()(Number<0>{}) = dst_v;
+#endif
+            }

             const bool is_dst_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
@@ -327,11 +357,27 @@ struct ThreadwiseTensorSliceTransfer_v1r4
                 {
                     move_tensor_coordinate(
                         dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
+
+                    // dst0
+                    move_tensor_coordinate(
+                        dst0_desc, dst0_coord_, dst0_forward_steps[dim_access_order[i]]);
+
+                    // dst1
+                    move_tensor_coordinate(
+                        dst1_desc, dst1_coord_, dst1_forward_steps[dim_access_order[i]]);
                 }
                 else
                 {
                     move_tensor_coordinate(
                         dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
+
+                    // dst0
+                    move_tensor_coordinate(
+                        dst0_desc, dst0_coord_, dst0_backward_steps[dim_access_order[i]]);
+
+                    // dst1
+                    move_tensor_coordinate(
+                        dst1_desc, dst1_coord_, dst1_backward_steps[dim_access_order[i]]);
                 }
             }
         });
@@ -469,7 +515,7 @@ struct ThreadwiseTensorSliceTransfer_v1r4
     DstCoord dst_coord_;
     Dst0Coord dst0_coord_;
     Dst1Coord dst1_coord_;

-    SrcElementwiseOperation src_element_op_;
+    const DstElementwiseOperation dst_element_op_;
 }; // namespace ck

 } // namespace ck
...
@@ -810,7 +810,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
     SrcCoord src_coord_;
     DstCoord dst_coord_;

-    SrcElementwiseOperation src_element_op_;
+    const SrcElementwiseOperation src_element_op_;
 };

 } // namespace ck
...
@@ -136,6 +136,11 @@
 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif

+// workaround for register spill due to compiler issue, when casting type between fp32 and fp16
+#ifndef CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
+#define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE 1
+#endif
+
 namespace ck {

 enum InMemoryDataOperationEnum_t
...
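Note (editorial sketch, not part of this commit): the macro above selects between the two paths in the @@ -258 hunk of ThreadwiseTensorSliceTransfer_v1r4. Reduced to a standalone helper (hypothetical name, assuming a compiler with _Float16 support such as hipcc/clang), the only difference is where the fp16/fp32 conversions happen:

// mirrors the #if/#else guarded by the workaround macro above
template <typename Op>
_Float16 apply_c_elementwise(const Op& op, float src_v, _Float16 dst0_v, _Float16 dst1_v)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R4_TYPE_CONVERT_ISSUE
    // workaround disabled: convert the fp16 operands up, do the math in fp32, convert once at the end
    const float r = op(src_v, static_cast<float>(dst0_v), static_cast<float>(dst1_v));
    return static_cast<_Float16>(r);
#else
    // workaround enabled (the default): pass the fp16 operands through and let the op
    // choose the mixed-precision order, avoiding the converts that triggered the register spill
    return static_cast<_Float16>(op(src_v, dst0_v, dst1_v));
#endif
}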
@@ -14,10 +14,6 @@
 #include "device_base.hpp"
 #include "example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp"

-// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
-// assume C0 has same layout as C
-// assume C1 is contiguous in memory
-
 struct PassThrough
 {
     template <typename T>
@@ -27,17 +23,60 @@ struct PassThrough
     }
 };

-struct Relu
-{
-    float alpha = 0.1;
-
-    // ReLU
-    template <typename T>
-    __host__ __device__ constexpr T operator()(T v) const
-    {
-        T tmp = alpha * v;
-        return tmp > 0 ? tmp : 0;
-    }
-};
+// GEMM Bias Add:
+// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
+// assume C0 has same layout as C
+// assume C1 is contiguous in memory
+// C1 presents in memory as 1d vector, but is represented as 2D matrix C1[m, n], with stride = 0 in
+// the "n" dimension
+//
+// alpha * v0 + beta * v1 + gamma * v2
+// v0 is from C matrix
+// v1 is from residual matrix
+// v2 is from bias vector
+struct BiasAdd
+{
+#if 1
+    // correct result
+    // no scratch memory, good VGPR allocation (59)
+    // good perf (101Tflops)
+    template <typename T1, typename T2>
+    __host__ __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const
+    {
+        // compiler seems very volatile to the order of these calculation:
+        // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register
+        // over-allocation. Therefore, move v0 calculation to the very end
+        float a = T1(0.2) * v1 + T2(0.3) * v2;
+        float b = a + float(0.1) * v0;
+
+        return b;
+    }
+#elif 0
+    // correct result
+    // some scratch memory (68), large VGPR usage (126)
+    // very little perf drop (101Tflops)
+    __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
+    {
+        return float(0.1) * v0 + ck::half_t(0.2) * v1 + ck::half_t(0.3) * v2;
+    }
+#elif 0
+    // correct result
+    // some scratch memory (68 dword)
+    // some perf drop (94Tflops)
+    // fp64 instructions are used
+    __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
+    {
+        return 0.1 * v0 + 0.2 * v1 + 0.3 * v2;
+    }
+#elif 1
+    // wrong result
+    // lots of scratch memory
+    // huge perf drop
+    __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const
+    {
+        return float(0.1) * v0 + float(0.2) * v1 + float(0.3) * v2;
+    }
+#endif
+};

 template <typename ADataType,
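Note (editorial sketch, not part of this commit): the stride-0 representation of the bias described in the BiasAdd comments can be pictured with a plain host-side strided view; the names below are hypothetical:

#include <cassert>
#include <vector>

// read-only 2D view: element (m, n) = data[m * stride_m + n * stride_n];
// with stride_n = 0, every column n of row m reads the same bias value
struct BiasView2D
{
    const float* data;
    int stride_m; // 1 for a contiguous bias vector
    int stride_n; // 0 => broadcast along "n"

    float operator()(int m, int n) const { return data[m * stride_m + n * stride_n]; }
};

int main()
{
    std::vector<float> c1 = {1.f, 2.f, 3.f}; // bias of length M = 3
    BiasView2D c1_m_n{c1.data(), 1, 0};

    assert(c1_m_n(2, 0) == 3.f && c1_m_n(2, 7) == 3.f); // any n returns the row's bias
    return 0;
}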
@@ -125,13 +164,49 @@ struct DeviceGemmInstance<float,
     // clang-format on
 };

+template <typename AType,
+          typename BType,
+          typename CType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+static void host_verify(const Tensor<AType>& a_m_k,
+                        const Tensor<BType>& b_k_n,
+                        Tensor<CType>& c_m_n,
+                        const Tensor<CType>& c0_m_n,
+                        const Tensor<CType>& c1_m_n,
+                        const AElementwiseOperation& a_element_op,
+                        const BElementwiseOperation& b_element_op,
+                        const CElementwiseOperation& c_element_op)
+{
+    auto f_mk_kn_mn = [&](auto m, auto n) {
+        const int K = a_m_k.mDesc.GetLengths()[1];
+
+        double v = 0;
+
+        for(int k = 0; k < K; ++k)
+        {
+            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
+                 static_cast<const double>(b_element_op(b_k_n(k, n)));
+        }
+
+        c_m_n(m, n) = c_element_op(
+            v, static_cast<const double>(c0_m_n(m, n)), static_cast<const double>(c1_m_n(m, n)));
+    };
+
+    make_ParallelTensorFunctor(f_mk_kn_mn,
+                               c_m_n.mDesc.GetLengths()[0],
+                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
+}
+
 int main(int argc, char* argv[])
 {
-    if(argc != 4)
+    if(argc != 10)
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: run kernel # of times (>1)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
         exit(0);
     }
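Note (editorial sketch, not part of this commit): host_verify above accumulates the GEMM in double and then applies the C elementwise op once per output element. On plain arrays, and plugging in the 0.1/0.2/0.3 coefficients from BiasAdd, the same reference math boils down to:

#include <cstddef>
#include <vector>

// C[m][n] = 0.1 * sum_k A[m][k] * B[k][n] + 0.2 * C0[m][n] + 0.3 * C1[m]
// (row-major layout assumed; the example itself uses CK's Tensor<> and make_ParallelTensorFunctor)
std::vector<float> reference_gemm_bias_add(const std::vector<float>& a,  // M x K
                                           const std::vector<float>& b,  // K x N
                                           const std::vector<float>& c0, // M x N
                                           const std::vector<float>& c1, // M (bias)
                                           int M, int N, int K)
{
    std::vector<float> c(static_cast<std::size_t>(M) * N);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            double acc = 0;
            for(int k = 0; k < K; ++k)
                acc += double(a[m * K + k]) * double(b[k * N + n]);

            c[m * N + n] = float(0.1 * acc + 0.2 * c0[m * N + n] + 0.3 * c1[m]);
        }

    return c;
}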
@@ -140,18 +215,24 @@ int main(int argc, char* argv[])
     const int nrepeat = std::stoi(argv[3]);

     // GEMM shape
-    ck::index_t M = 3840;
-    ck::index_t N = 4096;
-    ck::index_t K = 4096;
-
-    ck::index_t StrideA = 4096;
-    ck::index_t StrideB = 4096;
-    ck::index_t StrideC = 4096;
+    ck::index_t M = std::stoi(argv[4]);
+    ck::index_t N = std::stoi(argv[5]);
+    ck::index_t K = std::stoi(argv[6]);
+
+    ck::index_t StrideA = std::stoi(argv[7]);
+    ck::index_t StrideB = std::stoi(argv[8]);
+    ck::index_t StrideC = std::stoi(argv[9]);

     // matrix data type
+#if 1
     using ADataType = ck::half_t;
     using BDataType = ck::half_t;
     using CDataType = ck::half_t;
+#else
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+#endif

     // matrix layout
     using ALayout = ck::tensor_layout::gemm::RowMajor;
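Note (editorial): with the argument parsing above, the previously hard-coded shape can be reproduced by passing it explicitly, e.g. "1 1 5 3840 4096 4096 4096 4096 4096" (verification on, integer init, 5 timed runs, then M, N, K, StrideA, StrideB, StrideC), keeping M a multiple of 256, N a multiple of 128, and K a multiple of 32 as the usage message requires. The exact binary name depends on how the example is built.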
@@ -219,6 +300,8 @@ int main(int argc, char* argv[])
     c0_m_n_device_buf.ToDevice(c0_m_n.mData.data());
     c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());

+    auto c_element_op = BiasAdd{};
+
     // do GEMM
     auto gemm = typename DeviceGemmInstance<ADataType,
                                             BDataType,
@@ -228,7 +311,7 @@ int main(int argc, char* argv[])
                                             CLayout,
                                             PassThrough,
                                             PassThrough,
-                                            Relu>::type{};
+                                            decltype(c_element_op)>::type{};

     auto invoker = gemm.MakeInvoker();

     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
@@ -244,7 +327,7 @@ int main(int argc, char* argv[])
                                       StrideC,
                                       PassThrough{},
                                       PassThrough{},
-                                      Relu{});
+                                      c_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -270,8 +353,13 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        host_gemm_mk_kn_mn(a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, Relu{});
-
-        check_error(c_m_n_host_result, c_m_n_device_result);
+        host_verify(a_m_k,
+                    b_k_n,
+                    c_m_n_host_result,
+                    c0_m_n,
+                    c1_m_n,
+                    PassThrough{},
+                    PassThrough{},
+                    c_element_op);
     }
 }