Commit 5d07483b authored by Tri Dao

Refactor Gmem code to store q, k, v pointers separately

parent d3e64409
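The packed QKV tensor keeps its [total_s, 3, h, d] layout in memory; what changes is how the kernels address it. Instead of a single qkv_ptr plus a shared stride and a qkv_offset argument selecting Q, K or V, each matrix now carries its own base pointer together with explicit row and head strides, so the same Gmem_tile classes can also be pointed at O, dO or the dQKV buffer. A minimal host-side sketch of the pointer arithmetic, assuming fp16 (2 bytes per element) and the hypothetical name QkvView; the set_params hunk below is the authoritative version:

    #include <cstdint>

    // Sketch only: derive the per-matrix pointers and strides that set_params
    // (below) stores, for a packed QKV buffer laid out as [total_s, 3, h, d].
    struct QkvView {
        char *q_ptr, *k_ptr, *v_ptr;
        uint32_t row_stride_in_elts;   // 3 * h * d: one step along the sequence
        uint32_t head_stride_in_elts;  // d: one step along the heads
    };

    inline QkvView make_qkv_view(void *qkv_packed, int h, int d) {
        constexpr int BYTES_PER_ELT = 2;  // fp16
        char *base = static_cast<char *>(qkv_packed);
        return QkvView{
            base,                              // Q block of each token
            base + 1 * h * d * BYTES_PER_ELT,  // K follows Q within the token
            base + 2 * h * d * BYTES_PER_ELT,  // V follows K within the token
            uint32_t(3 * h * d),               // rows stay interleaved
            uint32_t(d),
        };
    }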
@@ -56,12 +56,18 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     memset(&params, 0, sizeof(params));
     // Set the pointers and strides.
-    params.qkv_ptr = qkv_packed_d;
-    params.qkv_stride_in_elts = h * 3 * d;
-    params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
+    params.q_ptr = qkv_packed_d;
+    params.k_ptr = qkv_packed_d + get_size_in_bytes(h * d, data_type);
+    params.v_ptr = qkv_packed_d + 2 * get_size_in_bytes(h * d, data_type);
+    params.q_row_stride_in_elts = 3 * h * d;
+    params.k_row_stride_in_elts = 3 * h * d;
+    params.v_row_stride_in_elts = 3 * h * d;
+    params.q_head_stride_in_elts = d;
+    params.k_head_stride_in_elts = d;
+    params.v_head_stride_in_elts = d;
     params.o_ptr = o_packed_d;
-    params.o_stride_in_elts = h * d;
-    params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
+    params.o_row_stride_in_elts = h * d;
+    params.o_head_stride_in_elts = d;
     params.do_ptr = do_packed_d;
     params.o_tmp_ptr = o_tmp_d;
...
@@ -50,15 +50,21 @@ constexpr int D_DIM = 3;
 struct Qkv_params {
     // The QKV matrices.
-    void * __restrict__ qkv_ptr;
+    void *__restrict__ q_ptr;
+    void *__restrict__ k_ptr;
+    void *__restrict__ v_ptr;
     // The stride between rows of the Q, K and V matrices.
     // size_t qkv_stride_in_elts;
     // size_t qkv_stride_in_bytes;
     // TD [2022-04-16]: We're using 32-bit indexing to save registers.
     // The code probably won't work for arrays larger than 2GB.
-    uint32_t qkv_stride_in_elts;
-    uint32_t qkv_stride_in_bytes;
+    uint32_t q_row_stride_in_elts;
+    uint32_t k_row_stride_in_elts;
+    uint32_t v_row_stride_in_elts;
+    uint32_t q_head_stride_in_elts;
+    uint32_t k_head_stride_in_elts;
+    uint32_t v_head_stride_in_elts;
     // The number of heads.
     int h;
@@ -71,17 +77,14 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // The dQKV matrices.
     void * __restrict__ dqkv_ptr;
-    // Temporary for dKV.
-    void * __restrict__ dkv_ptr;
     // The O matrix (output).
     void * __restrict__ o_ptr;
     // The stride between rows of O.
     // size_t o_stride_in_elts;
     // size_t o_stride_in_bytes;
-    uint32_t o_stride_in_elts;
-    uint32_t o_stride_in_bytes;
+    uint32_t o_row_stride_in_elts;
+    uint32_t o_head_stride_in_elts;
     // The pointer to the O_tmp matrix, which holds O intermediate value during
     // the loop;
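The 32-bit indexing note carries over to the new fields: every stride is a uint32_t, so every byte offset a kernel forms must fit in 2^31. A hypothetical host-side guard, not part of this commit, that makes the 2GB limit explicit (assuming fp16):

    #include <cassert>
    #include <cstdint>

    // Hypothetical guard: with uint32_t offsets, the largest byte offset
    // (total_s * 3 * h * d elements, 2 bytes each for fp16) must stay below 2^31.
    inline void check_32bit_indexing(int64_t total_s, int64_t h, int64_t d) {
        const int64_t qkv_bytes = total_s * 3 * h * d * 2;  // fp16
        assert(qkv_bytes < (int64_t(1) << 31) && "QKV too large for 32-bit indexing");
    }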
@@ -171,4 +174,4 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
 void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
-void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
\ No newline at end of file
+void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
@@ -39,14 +39,13 @@ template<
     // The number of rows of Q, K or V loaded by this tile.
     int ROWS_,
     // The number of columns.
-    int COLS,
-    // The number of matrics.
-    int NUM_MATS = 3
+    int COLS
 >
 struct Gmem_tile_qkv {
     using Cta_tile = Cta_tile_;
+    static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
     // The size of each LDG.
     static constexpr int BYTES_PER_LDG = 16;
     // The size of a row in bytes.
@@ -62,11 +61,12 @@ struct Gmem_tile_qkv {
     static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
     // Ctor.
-    template< typename Params, typename BInfo >
-    inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
-        : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
+    template< typename BInfo >
+    inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
+                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
+        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
         , actual_seqlen(binfo.actual_seqlen)
-        , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr))
+        , ptr(reinterpret_cast<char *>(ptr_))
         , tidx_(tidx) {
         // Compute the position in the sequence (within the CTA for the moment).
@@ -80,13 +80,13 @@ struct Gmem_tile_qkv {
         // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
-        // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
-        uint32_t row_offset = (uint32_t)row * params.qkv_stride_in_bytes;
+        uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
         // Add the block index.
-        // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
-        row_offset += (uint32_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
+        row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
         // Assemble the final pointer.
-        qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
+        ptr += row_offset + col * BYTES_PER_LDG;
     }
     // Store data to shared memory.
@@ -101,8 +101,8 @@ struct Gmem_tile_qkv {
         uint32_t preds[LDGS];
         #pragma unroll
         for( int ii = 0; ii < LDGS; ++ii ) {
-            // ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
-            ptrs[ii] = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
+            // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
+            ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
             fetch_[ii] = make_uint4(0, 0, 0, 0);
         }
@@ -120,32 +120,25 @@ struct Gmem_tile_qkv {
         int row_ = tidx_ / THREADS_PER_ROW;
         #pragma unroll
         for( int ii = 0; ii < LDGS; ++ii ) {
-            // char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
-            char *ptr = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
+            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
+            char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
-                fmha::stg(ptr, data[ii]);
+                fmha::stg(ptr_, data[ii]);
             }
         }
     }
-    // Move the pointer to the next location.
-    inline __device__ void move() {
-        // qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
-        qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_;
-        actual_seqlen -= ROWS;
-    }
-    inline __device__ void move(int steps) {
-        // qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
-        qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_ * steps;
+    inline __device__ void move(const int steps = 1) {
+        // ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
+        ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
         actual_seqlen -= ROWS * steps;
     }
     // The stride between rows for the QKV matrice.
-    // int64_t params_qkv_stride_in_bytes_;
-    uint32_t params_qkv_stride_in_bytes_;
+    // int64_t row_stride_in_bytes;
+    const uint32_t row_stride_in_bytes;
     // The pointer.
-    char *qkv_ptr_;
+    char *ptr;
     // The fetch registers.
     uint4 fetch_[LDGS];
     // Keep track of the row the thread is processing as we move the tile.
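Both Gmem_tile_qkv (above) and Gmem_tile_o (below) now compute the same three-term base offset: a sequence term, a head term, and the thread's column within the row. A standalone sketch of that computation; tile_base_offset is a hypothetical helper, and the column multiplier is BYTES_PER_LDG = 16 as in the tile above:

    #include <cstdint>

    // Sketch of the per-thread base address formed in the Gmem_tile_qkv ctor:
    // sequence offset, then head offset, then the thread's column in the row.
    inline uint32_t tile_base_offset(uint32_t sum_s,  // tokens before this batch entry
                                     uint32_t row,    // this thread's row in the tile
                                     uint32_t col,    // this thread's column in the row
                                     uint32_t bidh,   // head index of this CTA
                                     uint32_t row_stride_in_bytes,
                                     uint32_t head_stride_in_elts,
                                     uint32_t bytes_per_element) {
        uint32_t off = (sum_s + row) * row_stride_in_bytes;     // step in the seq dim
        off += bidh * head_stride_in_elts * bytes_per_element;  // step in the head dim
        off += col * 16;                                        // BYTES_PER_LDG = 16
        return off;
    }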
@@ -196,10 +189,10 @@ struct Gmem_tile_o {
     // Ctor.
     template<typename BInfo>
-    // inline __device__ Gmem_tile_o(void *ptr, const size_t stride_in_elts, const BInfo &binfo, const int tidx)
-    inline __device__ Gmem_tile_o(void *ptr, const uint32_t stride_in_elts, const BInfo &binfo, const int tidx)
-        : stride_in_bytes_(stride_in_elts * BYTES_PER_ELEMENT)
-        , actual_seqlen_(binfo.actual_seqlen)
+    // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
+    inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
+                                  const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
+        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
         , actual_seqlen(binfo.actual_seqlen)
         , ptr_(reinterpret_cast<char *>(ptr))
         , tidx_(tidx) {
@@ -213,8 +206,9 @@ struct Gmem_tile_o {
         // row_ = row;
         // The row offset in the batched GEMM.
-        // int64_t row_offset = (int64_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
-        uint32_t row_offset = (uint32_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
+        // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
+        uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
+        row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
         // Assemble the final pointer.
         ptr_ += row_offset + col * BYTES_PER_STG;
@@ -224,25 +218,19 @@ struct Gmem_tile_o {
         }
     }
-    template<typename Params, typename BInfo>
-    inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, const int tidx)
-        : Gmem_tile_o(params.o_ptr, params.o_stride_in_elts, binfo, tidx) {}
     // Store data to global memory.
     inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
         int row_ = tidx_ / THREADS_PER_ROW;
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            // if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
-            //     break;
             if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
                 break;
             }
             if (BYTES_PER_ELEMENT == 4) {
                 if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
-                    fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, src[ii]);
+                    fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
                 }
             } else if (BYTES_PER_ELEMENT == 2) {
                 float x = reinterpret_cast<const float &>(src[ii].x);
@@ -251,7 +239,7 @@ struct Gmem_tile_o {
                 float w = reinterpret_cast<const float &>(src[ii].w);
                 uint2 out = float4_to_half4(x, y, z, w);
                 if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
-                    fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, out);
+                    fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
                 }
             }
         }
@@ -269,37 +257,26 @@ struct Gmem_tile_o {
             }
             if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
-                fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_);
+                fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
             }
         }
     }
-    // Move the pointer to the next location.
-    inline __device__ void move() {
-        // row_ += ROWS;
-        // ptr_ += (int64_t)ROWS * stride_in_bytes_;
-        ptr_ += (uint32_t)ROWS * stride_in_bytes_;
-        actual_seqlen -= ROWS;
-    }
-    inline __device__ void move(const int steps) {
+    inline __device__ void move(const int steps = 1) {
         // row_ += ROWS * steps;
-        // ptr_ += (int64_t)ROWS * stride_in_bytes_ * steps;
-        ptr_ += (uint32_t)ROWS * stride_in_bytes_ * steps;
+        // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
+        ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
         actual_seqlen -= ROWS * steps;
     }
     // The stride between rows for the QKV matrice.
-    // int64_t stride_in_bytes_;
-    uint32_t stride_in_bytes_;
+    // int64_t row_stride_in_bytes;
+    const uint32_t row_stride_in_bytes;
     // The pointer.
     char *ptr_;
     // Is the thread active for the last STG?
     int is_active_for_last_stg_;
-    // Keep track of the row to disable loads.
-    // int row_;
     // The length of the sequence loaded by that memory tile.
-    const int actual_seqlen_;
     int actual_seqlen;
     const int tidx_;
 };
@@ -363,10 +340,7 @@ struct Gmem_tile_mma_sd {
     }
     // Move to the next tile.
-    inline __device__ void move() {
-        ptr_ += LOOP_STRIDE_BYTES;
-    }
-    inline __device__ void move(const int steps) {
+    inline __device__ void move(const int steps = 1) {
        ptr_ += LOOP_STRIDE_BYTES * steps;
    }
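The same consolidation is applied to all three tile types: the zero-argument move() and the move(steps) overload merge into one method with a defaulted step count, so existing call sites compile unchanged. A minimal self-contained sketch of the pattern (placeholder values, not the real tile):

    // Minimal sketch: one defaulted-parameter move() replaces the old
    // move() / move(steps) overload pair; both call forms still work.
    struct TileSketch {
        static constexpr unsigned LOOP_STRIDE_BYTES = 256;  // placeholder value
        unsigned offset_bytes = 0;
        void move(const int steps = 1) { offset_bytes += LOOP_STRIDE_BYTES * steps; }
    };

    int main() {
        TileSketch t;
        t.move();   // advances one tile, like the old move()
        t.move(4);  // advances four tiles, like the old move(steps)
        return t.offset_bytes == 5 * TileSketch::LOOP_STRIDE_BYTES ? 0 : 1;
    }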
@@ -459,69 +433,6 @@ struct Gmem_tile_mma_s : public Base {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<
-    // The dimensions of the tile computed by the CTA.
-    typename Cta_tile,
-    // The base class.
-    typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
->
-struct Gmem_tile_dout : public Base {
-    // Ctor.
-    template<typename Params, typename BInfo>
-    inline __device__ Gmem_tile_dout(void *ptr, const Params &params, const BInfo &binfo, int tidx)
-        : Base(params, 0, binfo, tidx) {
-        // this->qkv_ptr_ = reinterpret_cast<char *>(params.do_ptr);
-        this->qkv_ptr_ = static_cast<char *>(ptr);
-        this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move
-        // Compute the position in the sequence (within the CTA for the moment).
-        int row = tidx / Base::THREADS_PER_ROW;
-        // Compute the position of the thread in the row.
-        int col = tidx % Base::THREADS_PER_ROW;
-        // The row offset in the batched GEMM. For each seq element, we store O in that order.
-        // int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
-        // int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
-        uint32_t row_offset = (uint32_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
-        // Assemble the final pointer.
-        this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
-struct Gmem_tile_dq : public Base {
-    // Ctor.
-    template<typename Params, typename BInfo>
-    inline __device__ Gmem_tile_dq(const Params &params, const int qkv_offset, const BInfo &binfo, int tidx)
-        : Base(params.dqkv_ptr, params.qkv_stride_in_elts, binfo, tidx) {
-        this->ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
-        // Compute the position in the sequence (within the CTA for the moment).
-        int row = tidx / Base::THREADS_PER_ROW;
-        // Compute the position of the thread in the row.
-        int col = tidx % Base::THREADS_PER_ROW;
-        // The row offset in the batched GEMM. For each seq element, we store O in that order.
-        // int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
-        //     ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
-        // int64_t row_offset = (int64_t)row * this->stride_in_bytes_ +
-        //     ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
-        uint32_t row_offset = (uint32_t)row * this->stride_in_bytes_ +
-            ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
-        // Assemble the final pointer.
-        this->ptr_ += row_offset + col * Base::BYTES_PER_STG;
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<
     // The dimensions of the tile computed by the CTA.
     typename Cta_tile
...
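Nothing replaces these two structs one-for-one: with the pointer and both strides as plain constructor arguments, the kernels below point Gmem_tile_qkv at do_ptr for dO and Gmem_tile_o at dqkv_ptr for dQ. The new call pattern, excerpted from the hunks that follow:

    // dO: a plain Gmem_tile_qkv (via the Gmem_tile_do alias in kernel_traits)
    // aimed at do_ptr with the O strides.
    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts,
                         params.o_head_stride_in_elts, binfo, tidx);
    // dQ: a plain Gmem_tile_o (via Kernel_traits::Gmem_tile_o) aimed at the
    // packed dQKV buffer with the Q strides.
    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts,
                         params.q_head_stride_in_elts, binfo, tidx);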
@@ -72,9 +72,7 @@ struct FMHA_kernel_traits {
     // The shared memory tile to transpose S.
     using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
-    using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;
-    using Gmem_tile_dot = fmha::Gmem_tile_dout<Cta_tile_p, fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D> >;
+    using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
     // The global memory tile to store the softmax sum.
     using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
...
@@ -77,8 +77,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, ...
     using Gmem_tile_o = Gmem_tile_do;
     // The global memory tile to store dQ.
-    // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
-    using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
+    using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
     using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
     // The shared memory tile to swizzle dQ.
     using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
@@ -139,19 +138,19 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, ...
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
-    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -161,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, ...
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -173,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, ...
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -703,11 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, ...
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Qkv_params dv_params;
-    dv_params.qkv_ptr = params.dqkv_ptr;
-    dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
-    dv_params.h = params.h;
-    Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
+    Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
     }
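The raw offsets above select the dK and dV blocks inside the packed dQKV buffer: each matrix occupies params.h * params.d elements per token, and the trailing * 2 is the fp16 element size in bytes. A sketch of that arithmetic with a hypothetical helper, assuming fp16:

    #include <cstdint>

    // Hypothetical helper, assuming fp16 (2 bytes per element): base pointers
    // for dQ, dK, dV inside the packed dQKV buffer, matching the offsets above.
    inline char *dqkv_block_ptr(void *dqkv_ptr, int which /* 0=dQ, 1=dK, 2=dV */,
                                int h, int d) {
        constexpr int BYTES_PER_ELT = 2;  // fp16
        return static_cast<char *>(dqkv_ptr) + which * h * d * BYTES_PER_ELT;
    }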
@@ -718,11 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, ...
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Qkv_params dk_params;
-    dk_params.qkv_ptr = params.dqkv_ptr;
-    dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
-    dk_params.h = params.h;
-    Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
+    Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
     }
...
@@ -97,10 +97,10 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, ...
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
@@ -122,12 +122,12 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, ...
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
     // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
     Smem_tile_v smem_v(smem_v_, tidx);
@@ -193,7 +193,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, ...
         __syncthreads();
     }
     // Load the fragments for K.
     gemm_q_k.load_k();
     // Create the object to do the softmax.
...
@@ -80,8 +80,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...
     using Gmem_tile_o = Gmem_tile_do;
     // The global memory tile to store dQ.
-    // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
-    using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
+    using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
     using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
     // The shared memory tile to swizzle dQ.
     using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
@@ -132,19 +131,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
-    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -154,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -166,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -654,11 +653,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Qkv_params dv_params;
-    dv_params.qkv_ptr = params.dqkv_ptr;
-    dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
-    dv_params.h = params.h;
-    Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
+    Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
     }
@@ -669,11 +664,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Qkv_params dk_params;
-    dk_params.qkv_ptr = params.dqkv_ptr;
-    dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
-    dk_params.h = params.h;
-    Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
+    Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
     }
...
@@ -247,10 +247,10 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, ...
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, ...
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
...