Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
@@ -13,78 +13,94 @@ using cutlass::gemm::GemmShape;

// Add 128 bits padding when the last dim is a multiple of 256 bits
template <typename T, bool transpose, int M, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutA {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? M : K;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};

template <typename T, bool transpose, int N, int K, typename Enable = void>
struct DispatchSharedMemoryLayoutB {
  using Layout =
      typename std::conditional<transpose, cutlass::layout::ColumnMajor,
                                cutlass::layout::RowMajor>::type;
  static int constexpr Dim = transpose ? K : N;
  static int constexpr Stride =
      (Dim * sizeof(T) % 32 == 0) ? Dim + 16 / sizeof(T) : Dim;
};
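// Worked example of the stride rule above (illustration only, not part of this
// diff): with T = half_t (2 bytes) and Dim = 64, a row spans 64 * 2 = 128
// bytes, a multiple of 32 bytes (256 bits), so Stride becomes
// 64 + 16 / 2 = 72 elements, i.e. 16 bytes (128 bits) of padding per row,
// presumably to break up shared-memory bank conflicts. With Dim = 60, the row
// spans 120 bytes, which is not a multiple of 32, and Stride stays 60.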
// Partial specialization for half_t
template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, true, M, K,
                                   typename std::enable_if<M % 64 == 0>::type> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCongruous<16>;
  static int constexpr Stride = M;
};

template <int M, int K>
struct DispatchSharedMemoryLayoutA<half_t, false, M, K> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = M;
};

template <int N, int K> struct DispatchSharedMemoryLayoutB<half_t, true, N, K> {
  using Layout =
      cutlass::layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise<16, K>;
  static int constexpr Stride = N;
};

template <int N, int K>
struct DispatchSharedMemoryLayoutB<half_t, false, N, K,
                                   typename std::enable_if<N % 64 == 0>::type> {
  using Layout =
      cutlass::layout::RowMajorVoltaTensorOpMultiplicandBCongruous<16>;
  static int constexpr Stride = N;
};

template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  using A_type = A_type_raw;
  using B_type = B_type_raw;
  using C_type = C_type_raw;

  using InstructionShape = GemmShape<16, 16, 4>;
  using SMemLayoutA =
      typename DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                           Shape::kK>::Layout;
  using SMemLayoutB =
      typename DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                           Shape::kK>::Layout;
  static constexpr int stride_A =
      DispatchSharedMemoryLayoutA<A_type, trans_A, Shape::kM,
                                  Shape::kK>::Stride;
  static constexpr int stride_B =
      DispatchSharedMemoryLayoutB<B_type, trans_B, Shape::kN,
                                  Shape::kK>::Stride;
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          InstructionShape, 32, A_type,
          typename std::conditional<trans_A, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          B_type,
          typename std::conditional<trans_B, cutlass::layout::ColumnMajor,
                                    cutlass::layout::RowMajor>::type,
          C_type, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>,
      cutlass::MatrixShape<1, 1>>;

  static_assert(Shape::kM % num_warp_m == 0);
  static_assert(Shape::kN % num_warp_n == 0);

  using MmaWarp = typename cutlass::gemm::warp::MmaVoltaTensorOp<
      GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
                InstructionShape::kK>,
      A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
      cutlass::layout::RowMajor, Policy>;

  using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
  using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
@@ -97,13 +113,14 @@ class GemmTensorOp {
  static_assert(Shape::kK % InstructionShape::kK == 0);
  static int constexpr kKgroups = Shape::kK / InstructionShape::kK;

  static CUTLASS_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                                  FragmentC &accum, const int warp_idx_m,
                                  const int warp_idx_n, const int lane_id) {
    MmaWarp mma_op;
    FragmentA frag_A;
    FragmentB frag_B;
    const TensorRefA ref_A((A_type *)pA, stride_A);
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorA iter_A(ref_A, lane_id);
    IteratorB iter_B(ref_B, lane_id);
    iter_A.add_tile_offset({warp_idx_m, 0});
@@ -118,11 +135,12 @@ class GemmTensorOp {
    }
  }

  static CUTLASS_DEVICE void body_rs(const FragmentA *frag_A, B_type_raw *pB,
                                     FragmentC &accum, const int warp_idx_n,
                                     const int lane_id) {
    MmaWarp mma_op;
    FragmentB frag_B;
    const TensorRefB ref_B((B_type *)pB, stride_B);
    IteratorB iter_B(ref_B, lane_id);
    iter_B.add_tile_offset({0, warp_idx_n});
    CUTLASS_PRAGMA_UNROLL
@@ -136,27 +154,29 @@
namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body(pA, pB, *(FragmentC *)(accum), warp_id / num_warp_n,
            warp_id % num_warp_n, lane_id);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                           trans_B, A_type, B_type, C_type>;
  using FragmentA = typename MMA::FragmentA;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body_rs((const FragmentA *)(pA), pB, *(FragmentC *)(accum),
               warp_id % num_warp_n, lane_id);
}

}; // namespace tl
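For orientation, here is a minimal usage sketch of the tl::gemm_ss entry point shown above. It is not part of this commit: the kernel name, the 128x128x32 tile shape, the 2x2 warp arrangement (128 threads per block), and the shared-memory staging are illustrative assumptions, and it relies on this header (GemmTensorOp, GemmShape, tl::gemm_ss) being in scope.

// Hypothetical sketch, not part of this commit: fp16 operands, fp32 accumulation.
__global__ void example_gemm_tile_kernel(float *C) {
  __shared__ cutlass::half_t smem_A[128 * 32]; // staged A tile (assumed layout)
  __shared__ cutlass::half_t smem_B[32 * 128]; // staged B tile (assumed layout)
  // ... cooperative global->shared copy of the current A and B tiles ...
  __syncthreads();

  using MMA = GemmTensorOp<GemmShape<128, 128, 32>, 2, 2, false, false,
                           cutlass::half_t, cutlass::half_t, float>;
  float acc[MMA::FragmentC::kElements] = {0.f}; // per-thread accumulator

  tl::gemm_ss<128, 128, 32, 2, 2, false, false, cutlass::half_t,
              cutlass::half_t, float>(smem_A, smem_B, acc);

  // ... epilogue: write acc back to the C tile ...
}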
@@ -7,34 +7,29 @@
namespace tl {

struct SumOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x + y;
  }
};

struct MaxOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return cutlass::fast_max(x, y);
  }
};

struct MinOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return cutlass::fast_min(x, y);
  }
};

template <class Reducer, int threads, int scale> struct AllReduce {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                threads == 128 or threads == 64 or threads == 32 or
                threads == 16 or threads == 8 or threads == 4 or threads == 2);
  static_assert(threads % scale == 0);

  template <typename T> static TL_DEVICE T run(T x, T *red_buf = nullptr) {
    constexpr int offset = threads / 2;
    if constexpr (offset >= 32) {
      __syncthreads();
@@ -54,4 +49,4 @@ struct AllReduce {
  }
};

} // namespace tl
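Likewise, a minimal sketch of how AllReduce might be called for a block-wide sum. The kernel name, the block size of 128, the interpretation of scale == 1 as a whole-block reduction, and the size of the red_buf staging buffer are all assumptions, not taken from this commit.

// Hypothetical sketch, not part of this commit: each of 128 threads contributes
// one float; after the call every thread is assumed to hold the block-wide sum.
__global__ void example_block_sum_kernel(const float *in, float *out) {
  __shared__ float red_buf[128]; // cross-warp staging buffer (assumed size)
  float v = in[blockIdx.x * 128 + threadIdx.x];
  float total = tl::AllReduce<tl::SumOp, 128, 1>::run(v, red_buf);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;
  }
}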