"experiments/pyexps/ScaleHost.py" did not exist on "84da17e3848d356b0efafaf267c3a2cb73b0af50"
Commit 96f57e07 authored by Anthony Chang
Browse files

tighten up example code

parent d9708da8
......@@ -105,8 +105,8 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
......@@ -114,7 +114,7 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......@@ -129,7 +129,7 @@ using DeviceGemmInstance =
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
......@@ -137,7 +137,7 @@ using DeviceGemmInstance =
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
......@@ -235,21 +235,22 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t K = 256;
ck::index_t O = 256;
ck::index_t G0 = 1;
ck::index_t G1 = 1;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = 128;
ck::index_t O = 128;
ck::index_t G0 = 3;
ck::index_t G1 = 2;
float alpha = 1;
// float alpha = 1.f / std::sqrt(K); // TODO: make scaling aware
float alpha = 1.f;
bool input_permute = false;
bool output_permute = false;
......@@ -357,7 +358,7 @@ int run(int argc, char* argv[])
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
......@@ -487,6 +488,10 @@ int run(int argc, char* argv[])
return 0;
}
if(alpha != 1.0f)
{
std::cout << "not yet implemented scaling" << std::endl; // TODO: make scaling aware
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......@@ -512,6 +517,10 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment