Commit b37cb71f authored by Wen-Heng (Jack) Chung

Enable bwd wrw

parent c5143bca
@@ -7,98 +7,160 @@
namespace ck {

template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
    for(index_t i = 0; i < Matrix::NRow(); ++i)
    {
        for(index_t j = 0; j < Matrix::NCol(); ++j)
        {
            const index_t id = Matrix::CalculateOffset(i, j);

            p_thread[id] = Float(0);
        }
    }
}

template <typename SrcMatrix,
          typename DstMatrix,
          index_t NSliceRow,
          index_t NSliceCol,
          index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy
{
    __device__ constexpr ThreadwiseMatrixSliceCopy()
    {
        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
                          DstMatrix::RowStride() % DataPerAccess == 0,
                      "wrong! wrong alignment");

        static_assert(NSliceCol % DataPerAccess == 0,
                      "wrong! should be NSliceCol % DataPerAccess == 0");
    }

    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;

        for(index_t i = 0; i < NSliceRow; ++i)
        {
            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
            {
                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
                const index_t dst_index = DstMatrix::CalculateOffset(i, j);

                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
    }
};
// C += transpose(A) * B
// Element of matrix can be vectorized data
template <typename MatrixA, typename MatrixB, typename MatrixC>
struct ThreadwiseGemmTransANormalBNormalC
{
    __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
    {
        static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
                          MatrixB::NCol() == MatrixC::NCol(),
                      "wrong!");
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                for(index_t n = 0; n < N; ++n)
                {
                    const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
                    const index_t bindex = MatrixB::CalculateOffset(k, n);
                    const index_t cindex = MatrixC::CalculateOffset(m, n);

                    p_c[cindex] +=
                        inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
                }
            }
        }
    }

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        constexpr index_t M = MatrixC::NRow();
        constexpr index_t N = MatrixC::NCol();
        constexpr index_t K = MatrixA::NRow(); // A is transposed

        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

        for(index_t k = 0; k < K; ++k)
        {
            for(index_t m = 0; m < M; ++m)
            {
                const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed

                static_if<N == 2>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);

                    __outer_product_1x2(
                        p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
                });

                static_if<N == 4>{}([&](auto) {
                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
                    const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
                    const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);

                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
                    const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
                    const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);

                    __outer_product_1x4(p_a[aindex],
                                        p_b[bindex_0],
                                        p_b[bindex_1],
                                        p_b[bindex_2],
                                        p_b[bindex_3],
                                        p_c[cindex_0],
                                        p_c[cindex_1],
                                        p_c[cindex_2],
                                        p_c[cindex_3]);
                });
            }
        }
    }
#endif

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
                                     ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
                                      (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
                                      (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

        static_if<has_amd_asm>{}([&](auto fwd) {
            Run_amd_asm(p_a, p_b, fwd(p_c));
        }).Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
        Run_source(p_a, p_b, p_c);
#endif
    }
};
} // namespace ck
#endif
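Note (not part of the commit): Run_source implements C += transpose(A) * B, where A is stored K x M (already transposed), B is K x N, and C is M x N, so A(k, m) pairs with B(k, n) and accumulates into C(m, n). Below is a minimal host-side sketch in plain C++, assuming simple row-major buffers in place of the ck matrix descriptor types; their CalculateOffset(row, col) plays the role of the explicit index arithmetic here.

// Hypothetical standalone illustration of the Run_source accumulation pattern.
// Row-major layout is an assumption made for this sketch only.
#include <array>
#include <cstdio>

template <int M, int N, int K>
void gemm_transA_normalB_normalC(const float* a, // K x M, i.e. A stored transposed
                                 const float* b, // K x N
                                 float* c)       // M x N, accumulated into
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}

int main()
{
    constexpr int M = 2, N = 4, K = 3;
    std::array<float, K * M> a{};
    std::array<float, K * N> b{};
    std::array<float, M * N> c{}; // starts at zero, like threadwise_matrix_set_zero

    for(int i = 0; i < K * M; ++i) a[i] = 1.0f;
    for(int i = 0; i < K * N; ++i) b[i] = 2.0f;

    gemm_transA_normalB_normalC<M, N, K>(a.data(), b.data(), c.data());

    // Each C element is a length-K inner product: sum over k of 1 * 2 = 6.
    printf("c[0] = %f (expect %f)\n", c[0], 2.0f * K);
    return 0;
}

In the actual kernel the per-element product goes through inner_product_with_conversion<FloatC>, which presumably is what lets vectorized element types such as half2_t reduce into a float accumulator, and the Run dispatcher swaps in the inline-asm outer-product path only for the float/half2_t/half4_t cases listed above.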
@@ -4,12 +4,9 @@
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
#include "bfloat16_dev.hpp"
#define CK_DEVICE_BACKEND_AMD 1 #define CK_DEVICE_BACKEND_AMD 1
#define CK_USE_AMD_INLINE_ASM 1 #define CK_USE_AMD_INLINE_ASM 1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
namespace ck { namespace ck {