"docs/vscode:/vscode.git/clone" did not exist on "5710567ce3dc1f9aab32a910d79d06a88f95f56d"
Commit 6d0e78bd authored by Jing Zhang's avatar Jing Zhang
Browse files

improve weight layout

parent 5d42067e
...@@ -40,8 +40,7 @@ using DeviceGemmV2Instance = ...@@ -40,8 +40,7 @@ using DeviceGemmV2Instance =
1, 1, S<1, 16, 1, 4>, 4, 1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
static int NPerBlock = 16; [[maybe_unused]] static int KPerBlock = 256;
static int KPerBlock = 256;
#else #else
128, 128,
16, 32, 16, 32,
...@@ -53,10 +52,9 @@ using DeviceGemmV2Instance = ...@@ -53,10 +52,9 @@ using DeviceGemmV2Instance =
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4, 1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
static int NPerBlock = 32; [[maybe_unused]]static int KPerBlock = 128;
static int KPerBlock = 128;
#endif #endif
// clang-format on // clang-format on
...@@ -125,7 +123,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -125,7 +123,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1}); b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break; break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break; break;
case 2: case 2:
...@@ -153,31 +151,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -153,31 +151,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute // weight permute
#if 0 #if 1
int N1 = NPerBlock;
int K1 = KPerBlock; int K1 = KPerBlock;
int K0 = K / KPerBlock;
int N0 = N / N1; // int K0, N, K1
int K0 = K / K1; for(int j = 0; j < K0; j++)
int K01 = K0 / KBatch;
int K00 = KBatch;
std::cout << "K00 = " << K00 << " K01 = " << K01 << std::endl;
for(int k = 0; k < K00; k++)
{ {
for(int i = 0; i < N0; i++) for(int i = 0; i < N; i++)
{ {
for(int j = 0; j < K01; j++) for(int jj = 0; jj < K1; jj++)
{ {
for(int ii = 0; ii < N1; ii++) b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(k * N0 * K01 * N1 * K1 + i * K01 * N1 * K1 + j * N1 * K1 + ii * K1 + jj) =
b_k_n((i * N1 + ii) * K + (k * K01 * K1 + j * K1 + jj));
}
}
} }
} }
} }
...@@ -286,7 +271,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -286,7 +271,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result, pass &= ck::utils::check_err(c_m_n_device_result,
......
...@@ -50,7 +50,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -50,7 +50,7 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8); auto h_f16 = ck::type_convert<ck::half_t>(x_h - 8);
return {h_f16, l_f16}; return {h_f16, l_f16};
#else #elif 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q); uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
int x_l = (x_u8 & 0x0f); int x_l = (x_u8 & 0x0f);
...@@ -62,6 +62,9 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) ...@@ -62,6 +62,9 @@ __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
int lo = (x_l | x_h) | EX; int lo = (x_l | x_h) | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#else
int32_t res = bit_cast<int8_t>(q);
return bit_cast<half2_t>(res);
#endif #endif
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
//#define WEIGHT_PERMUTE #define WEIGHT_PERMUTE
namespace ck { namespace ck {
...@@ -399,22 +399,19 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -399,22 +399,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
#else #else
constexpr index_t N1 = NPerBlock;
constexpr index_t BK01 = KPerBlock / BK1Value; constexpr index_t BK01 = KPerBlock / BK1Value;
const index_t BK00 = BK0 / BK01; const index_t BK00 = BK0 / BK01;
const index_t N0 = N / N1;
const auto b_grid_desc_n0_bk00_n1_bk01_bk1 = const auto b_grid_desc_bk00_n_bk01_bk1 =
make_naive_tensor_descriptor_packed(make_tuple(N0, BK00, N1, BK01, BK1Value)); make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n0_bk00_n1_bk01_bk1, b_grid_desc_bk00_n_bk01_bk1,
make_tuple(make_merge_transform(make_tuple(BK00, BK01)), make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
make_merge_transform(make_tuple(N0, N1)), make_pass_through_transform(make_tuple(N)),
make_pass_through_transform(BK1Value)), make_pass_through_transform(BK1Value)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif #endif
return b_grid_desc_bk0_n_bk1; return b_grid_desc_bk0_n_bk1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment