Commit f43ed837 authored by danyao12's avatar danyao12
Browse files

pt1 fix for time kernel ,not save z matrix

parent 24bf4051
...@@ -30,7 +30,6 @@ Kernel outputs: ...@@ -30,7 +30,6 @@ Kernel outputs:
#include <numeric> #include <numeric>
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <fstream>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...@@ -264,8 +263,8 @@ int run(int argc, char* argv[]) ...@@ -264,8 +263,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512; ck::index_t M = 512; // 512
ck::index_t N = 512; ck::index_t N = 512; // 512
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 64; ck::index_t O = 64;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 4; // 54
...@@ -501,14 +500,60 @@ int run(int argc, char* argv[]) ...@@ -501,14 +500,60 @@ int run(int argc, char* argv[])
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data()); ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero(); kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); vgrad_device_buf.SetZero();
// z_device_buf.SetZero();
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
// get z matrix
{
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
invoker.Run(argument, StreamConfig{nullptr, false});
}
// not need output z matrix
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()), static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()), static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()), static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()), static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
...@@ -540,14 +585,8 @@ int run(int argc, char* argv[]) ...@@ -540,14 +585,8 @@ int run(int argc, char* argv[])
YElementOp{}, YElementOp{},
p_drop, p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset)); std::tuple<unsigned long long, unsigned long long>(seed, offset));
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
if(!gemm.IsSupportedArgument(argument)) vgrad_device_buf.SetZero();
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// 5 GEMM ops in total: // 5 GEMM ops in total:
...@@ -570,9 +609,8 @@ int run(int argc, char* argv[]) ...@@ -570,9 +609,8 @@ int run(int argc, char* argv[])
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
// copy z matirx data form device // copy z matirx data form device
std::ofstream file("./z_matrix_txt");
z_device_buf.FromDevice(z_g_m_n.mData.data()); z_device_buf.FromDevice(z_g_m_n.mData.data());
file << z_g_m_n << std::endl;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n; // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
...@@ -595,45 +633,7 @@ int run(int argc, char* argv[]) ...@@ -595,45 +633,7 @@ int run(int argc, char* argv[])
}); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
//
// call kernel again // call kernel again
//
// example set Z matrix to null, will not ouput z matrix data
argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero(); vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment