Commit 96f57e07 authored by Anthony Chang's avatar Anthony Chang
Browse files

tighten up example code

parent d9708da8
...@@ -105,8 +105,8 @@ using DeviceGemmInstance = ...@@ -105,8 +105,8 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
64, // Gemm1NPerBlock 128, // Gemm1NPerBlock
32, // Gemm1KPerBlock 64, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
...@@ -114,7 +114,7 @@ using DeviceGemmInstance = ...@@ -114,7 +114,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
...@@ -129,7 +129,7 @@ using DeviceGemmInstance = ...@@ -129,7 +129,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
S<16, 16, 1>, // B1BlockTransfer S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
1, 1,
...@@ -137,7 +137,7 @@ using DeviceGemmInstance = ...@@ -137,7 +137,7 @@ using DeviceGemmInstance =
2, 2,
false, false,
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
...@@ -235,21 +235,22 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -235,21 +235,22 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[]) int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = false; bool time_kernel = true;
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 256; ck::index_t M = 512;
ck::index_t N = 256; ck::index_t N = 512;
ck::index_t K = 256; ck::index_t K = 128;
ck::index_t O = 256; ck::index_t O = 128;
ck::index_t G0 = 1; ck::index_t G0 = 3;
ck::index_t G1 = 1; ck::index_t G1 = 2;
float alpha = 1; // float alpha = 1.f / std::sqrt(K); // TODO: make scaling aware
float alpha = 1.f;
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -357,7 +358,7 @@ int run(int argc, char* argv[]) ...@@ -357,7 +358,7 @@ int run(int argc, char* argv[])
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
break; break;
case 3: case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
...@@ -487,6 +488,10 @@ int run(int argc, char* argv[]) ...@@ -487,6 +488,10 @@ int run(int argc, char* argv[])
return 0; return 0;
} }
if(alpha != 1.0f)
{
std::cout << "not yet implemented scaling" << std::endl; // TODO: make scaling aware
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...@@ -512,6 +517,10 @@ int run(int argc, char* argv[]) ...@@ -512,6 +517,10 @@ int run(int argc, char* argv[])
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
{ {
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K}); Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K}); Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O}); Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment