Commit ce31e22a authored by guangzlu
Browse files

modified batched_multihead_attention_train.cpp

parent 3a9dabcf
...@@ -32,7 +32,7 @@ Kernel outputs: ...@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define DIM 32 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -720,8 +720,8 @@ int run(int argc, char* argv[]) ...@@ -720,8 +720,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512; // 512 ck::index_t M = 1000; // 512
ck::index_t N = 512; // 512 ck::index_t N = 1000; // 512
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 4; // 54
...@@ -1041,8 +1041,8 @@ int run(int argc, char* argv[]) ...@@ -1041,8 +1041,8 @@ int run(int argc, char* argv[])
p_drop, p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
qgrad_device_buf.SetZero(); // reset global accum buffer and rerun qgrad_device_buf.SetZero(); // reset global accum buffer and rerun
kgrad_device_buf.SetZero(); // kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); // vgrad_device_buf.SetZero();
float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true}); float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true});
// 5 GEMM ops in total: // 5 GEMM ops in total:
...@@ -1151,8 +1151,8 @@ int run(int argc, char* argv[]) ...@@ -1151,8 +1151,8 @@ int run(int argc, char* argv[])
fwd_file << z_fwd_gs_ms_ns << std::endl; fwd_file << z_fwd_gs_ms_ns << std::endl;
qgrad_device_buf.SetZero(); qgrad_device_buf.SetZero();
kgrad_device_buf.SetZero(); // kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); // vgrad_device_buf.SetZero();
auto argument_bwd = gemm_bwd.MakeArgument( auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment