Commit 5d07483b authored by Tri Dao

Refactor Gmem code to store q, k, v pointers separately

parent d3e64409
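
The refactor replaces the single packed qkv_ptr and qkv_stride with one base pointer plus a row stride and a head stride per matrix. A minimal host-side sketch of the relationship, assuming the packed [total_s, 3, h, d] QKV layout that set_params (below) still uses; QkvPtrs and split_packed_qkv are illustrative names, not part of the commit:

#include <cstddef>
#include <cstdint>

// Illustrative only: derive the three base pointers and the per-matrix
// strides from a packed QKV buffer, mirroring set_params below.
struct QkvPtrs {
    char *q, *k, *v;
    uint32_t row_stride_in_elts;   // elements between consecutive seq positions
    uint32_t head_stride_in_elts;  // elements between consecutive heads
};

inline QkvPtrs split_packed_qkv(void *qkv_packed, uint32_t h, uint32_t d, size_t elt_size) {
    char *base = static_cast<char *>(qkv_packed);
    QkvPtrs p;
    p.q = base;                        // Q starts at the packed buffer
    p.k = base + 1 * h * d * elt_size; // K is one [h, d] slab further
    p.v = base + 2 * h * d * elt_size; // V is two [h, d] slabs further
    p.row_stride_in_elts  = 3 * h * d; // one full (Q, K, V) slab per seq position
    p.head_stride_in_elts = d;         // heads are contiguous within a slab
    return p;
}
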
@@ -56,12 +56,18 @@ void set_params(Fused_multihead_attention_fprop_params &params,
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.qkv_ptr = qkv_packed_d;
params.qkv_stride_in_elts = h * 3 * d;
params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
params.q_ptr = qkv_packed_d;
params.k_ptr = qkv_packed_d + get_size_in_bytes(h * d, data_type);
params.v_ptr = qkv_packed_d + 2 * get_size_in_bytes(h * d, data_type);
params.q_row_stride_in_elts = 3 * h * d;
params.k_row_stride_in_elts = 3 * h * d;
params.v_row_stride_in_elts = 3 * h * d;
params.q_head_stride_in_elts = d;
params.k_head_stride_in_elts = d;
params.v_head_stride_in_elts = d;
params.o_ptr = o_packed_d;
params.o_stride_in_elts = h * d;
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.o_row_stride_in_elts = h * d;
params.o_head_stride_in_elts = d;
params.do_ptr = do_packed_d;
params.o_tmp_ptr = o_tmp_d;
......
@@ -50,15 +50,21 @@ constexpr int D_DIM = 3;
struct Qkv_params {
// The QKV matrices.
void * __restrict__ qkv_ptr;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_bytes;
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
uint32_t qkv_stride_in_elts;
uint32_t qkv_stride_in_bytes;
uint32_t q_row_stride_in_elts;
uint32_t k_row_stride_in_elts;
uint32_t v_row_stride_in_elts;
uint32_t q_head_stride_in_elts;
uint32_t k_head_stride_in_elts;
uint32_t v_head_stride_in_elts;
// The number of heads.
int h;
@@ -71,17 +77,14 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void * __restrict__ dqkv_ptr;
// Temporary for dKV.
void * __restrict__ dkv_ptr;
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
// size_t o_stride_in_elts;
// size_t o_stride_in_bytes;
uint32_t o_stride_in_elts;
uint32_t o_stride_in_bytes;
uint32_t o_row_stride_in_elts;
uint32_t o_head_stride_in_elts;
// The pointer to the O_tmp matrix, which holds intermediate values of O
// during the loop.
......
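
The comment above pins strides to uint32_t to save registers, at the cost of breaking for tensors of 2GB or more. A hypothetical host-side guard (check_32bit_indexing is not in the commit) that rejects such inputs before launch:

#include <cstdint>
#include <stdexcept>

// Every byte offset the kernels form is bounded by the total QKV extent,
// so rejecting tensors of 2GB or more keeps 32-bit offsets from overflowing.
inline void check_32bit_indexing(size_t total_s, size_t h, size_t d, size_t elt_size) {
    size_t total_bytes = total_s * 3 * h * d * elt_size;
    if (total_bytes >= (size_t(1) << 31)) {
        throw std::runtime_error("QKV tensor is >= 2GB; 32-bit offsets would overflow");
    }
}
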
@@ -39,14 +39,13 @@ template<
// The number of rows of Q, K or V loaded by this tile.
int ROWS_,
// The number of columns.
int COLS,
// The number of matrices.
int NUM_MATS = 3
int COLS
>
struct Gmem_tile_qkv {
using Cta_tile = Cta_tile_;
static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
// The size of each LDG.
static constexpr int BYTES_PER_LDG = 16;
// The size of a row in bytes.
@@ -62,11 +61,12 @@ struct Gmem_tile_qkv {
static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
template< typename BInfo >
inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr))
, ptr(reinterpret_cast<char *>(ptr_))
, tidx_(tidx) {
// Compute the position in the sequence (within the CTA for the moment).
@@ -80,13 +80,13 @@ struct Gmem_tile_qkv {
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
uint32_t row_offset = (uint32_t)row * params.qkv_stride_in_bytes;
uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
// Add the block index.
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (uint32_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
// Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
ptr += row_offset + col * BYTES_PER_LDG;
}
// Store data to shared memory.
@@ -101,8 +101,8 @@ struct Gmem_tile_qkv {
uint32_t preds[LDGS];
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
ptrs[ii] = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
// ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
fetch_[ii] = make_uint4(0, 0, 0, 0);
}
@@ -120,32 +120,25 @@ struct Gmem_tile_qkv {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
char *ptr = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
fmha::stg(ptr, data[ii]);
fmha::stg(ptr_, data[ii]);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
// qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_;
actual_seqlen -= ROWS;
}
inline __device__ void move(int steps) {
// qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_ * steps;
inline __device__ void move(const int steps = 1) {
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows of the matrix.
// int64_t params_qkv_stride_in_bytes_;
uint32_t params_qkv_stride_in_bytes_;
// int64_t row_stride_in_bytes;
const uint32_t row_stride_in_bytes;
// The pointer.
char *qkv_ptr_;
char *ptr;
// The fetch registers.
uint4 fetch_[LDGS];
// Keep track of the row the thread is processing as we move the tile.
@@ -196,10 +189,10 @@ struct Gmem_tile_o {
// Ctor.
template<typename BInfo>
// inline __device__ Gmem_tile_o(void *ptr, const size_t stride_in_elts, const BInfo &binfo, const int tidx)
inline __device__ Gmem_tile_o(void *ptr, const uint32_t stride_in_elts, const BInfo &binfo, const int tidx)
: stride_in_bytes_(stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen_(binfo.actual_seqlen)
// inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen(binfo.actual_seqlen)
, ptr_(reinterpret_cast<char *>(ptr))
, tidx_(tidx) {
@@ -213,8 +206,9 @@ struct Gmem_tile_o {
// row_ = row;
// The row offset in the batched GEMM.
// int64_t row_offset = (int64_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
// Assemble the final pointer.
ptr_ += row_offset + col * BYTES_PER_STG;
@@ -224,25 +218,19 @@ struct Gmem_tile_o {
}
}
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, const int tidx)
: Gmem_tile_o(params.o_ptr, params.o_stride_in_elts, binfo, tidx) {}
// Store data to global memory.
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
// if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
// break;
if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
break;
}
if (BYTES_PER_ELEMENT == 4) {
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, src[ii]);
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
}
} else if (BYTES_PER_ELEMENT == 2) {
float x = reinterpret_cast<const float &>(src[ii].x);
@@ -251,7 +239,7 @@ struct Gmem_tile_o {
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = float4_to_half4(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, out);
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
}
}
}
@@ -269,37 +257,26 @@ struct Gmem_tile_o {
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_);
}
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
}
}
// Move the pointer to the next location.
inline __device__ void move() {
// row_ += ROWS;
// ptr_ += (int64_t)ROWS * stride_in_bytes_;
ptr_ += (uint32_t)ROWS * stride_in_bytes_;
actual_seqlen -= ROWS;
}
inline __device__ void move(const int steps) {
inline __device__ void move(const int steps = 1) {
// row_ += ROWS * steps;
// ptr_ += (int64_t)ROWS * stride_in_bytes_ * steps;
ptr_ += (uint32_t)ROWS * stride_in_bytes_ * steps;
// ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows of the O matrix.
// int64_t stride_in_bytes_;
uint32_t stride_in_bytes_;
// int64_t row_stride_in_bytes;
const uint32_t row_stride_in_bytes;
// The pointer.
char *ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// Keep track of the row to disable loads.
// int row_;
// The length of the sequence loaded by that memory tile.
const int actual_seqlen_;
int actual_seqlen;
const int tidx_;
};
@@ -363,10 +340,7 @@ struct Gmem_tile_mma_sd {
}
// Move to the next tile.
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
inline __device__ void move(const int steps) {
inline __device__ void move(const int steps = 1) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
@@ -459,69 +433,6 @@ struct Gmem_tile_mma_s : public Base {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The base class.
typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
>
struct Gmem_tile_dout : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dout(void *ptr, const Params &params, const BInfo &binfo, int tidx)
: Base(params, 0, binfo, tidx) {
// this->qkv_ptr_ = reinterpret_cast<char *>(params.do_ptr);
this->qkv_ptr_ = static_cast<char *>(ptr);
this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / Base::THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
// int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
struct Gmem_tile_dq : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dq(const Params &params, const int qkv_offset, const BInfo &binfo, int tidx)
: Base(params.dqkv_ptr, params.qkv_stride_in_elts, binfo, tidx) {
this->ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / Base::THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
// int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
// ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * this->stride_in_bytes_ +
// ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)row * this->stride_in_bytes_ +
((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->ptr_ += row_offset + col * Base::BYTES_PER_STG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile
......
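
For the packed layout, the new constructor offset, (sum_s + row) * row_stride plus bidh * head_stride, lands on the same byte as the old single-pointer formula once the matrix offset is folded into the base pointer. A small illustrative check (not part of the commit; m is 0/1/2 for Q/K/V, esz the element size in bytes):

#include <cassert>
#include <cstdint>

void check_offsets_agree(uint32_t sum_s, uint32_t row, uint32_t m,
                         uint32_t h, uint32_t bidh, uint32_t d, uint32_t esz) {
    uint32_t bytes_per_row = d * esz;  // one head-row of one matrix
    // Old scheme: one base pointer; the offset picks out (matrix, head) per row.
    uint32_t old_off = row * (3 * h * d * esz)
                     + ((sum_s * 3 + m) * h + bidh) * bytes_per_row;
    // New scheme: the base pointer absorbs m * h * d * esz; strides do the rest.
    uint32_t new_off = m * (h * d * esz)
                     + (sum_s + row) * (3 * h * d * esz)
                     + bidh * (d * esz);
    assert(old_off == new_off);
}
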
@@ -72,9 +72,7 @@ struct FMHA_kernel_traits {
// The shared memory tile to transpose S.
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;
using Gmem_tile_dot = fmha::Gmem_tile_dout<Cta_tile_p, fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D> >;
using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// The global memory tile to store the softmax sum.
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
......
@@ -77,8 +77,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
using Gmem_tile_o = Gmem_tile_do;
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
// The shared memory tile to swizzle dQ.
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
@@ -139,19 +138,19 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -161,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the shared memory tile loader for dO.
Smem_tile_do smem_do(&smem_[0], tidx);
Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -173,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -703,11 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
if (!Is_first) {
gmem_dv.move(loop_step_idx);
}
@@ -718,11 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
}
......
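
The bare trailing "* 2" in the gmem_dv and gmem_dk base-pointer arithmetic above is the fp16 element size in bytes (and the arithmetic on the void* dqkv_ptr leans on the compiler's void-pointer extension). A hypothetical helper (not in the commit) that spells out the intent:

#include <cstddef>
#include <cuda_fp16.h>

// Hypothetical: matrix m of the packed [total_s, 3, h, d] gradient tensor
// starts m slabs of h * d half elements past the base pointer; sizeof(__half)
// replaces the bare "* 2".
inline char *dqkv_matrix_base(void *dqkv_ptr, int m /* 0=dQ, 1=dK, 2=dV */, int h, int d) {
    return static_cast<char *>(dqkv_ptr) + size_t(m) * h * d * sizeof(__half);
}
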
@@ -97,10 +97,10 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
@@ -122,9 +122,9 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
......
@@ -80,8 +80,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
using Gmem_tile_o = Gmem_tile_do;
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
// The shared memory tile to swizzle dQ.
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
@@ -132,19 +131,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -154,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the shared memory tile loader for dO.
Smem_tile_do smem_do(&smem_[0], tidx);
Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -166,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -654,11 +653,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
if (!Is_first) {
gmem_dv.move(loop_step_idx);
}
@@ -669,11 +664,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
}
......
@@ -247,10 +247,10 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
......
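
With a pointer plus row and head strides passed per matrix, nothing in the tiles depends on the packed layout anymore. A hypothetical set_params variant (not part of this commit) for separate Q, K, V tensors, each laid out as [total_s, h, d], with q_d, k_d, v_d as illustrative device pointers:

params.q_ptr = q_d;
params.k_ptr = k_d;
params.v_ptr = v_d;
params.q_row_stride_in_elts = h * d;   // rows are h * d apart, not 3 * h * d
params.k_row_stride_in_elts = h * d;
params.v_row_stride_in_elts = h * d;
params.q_head_stride_in_elts = d;
params.k_head_stride_in_elts = d;
params.v_head_stride_in_elts = d;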