"platforms/hip/src/HipKernelFactory.cpp" did not exist on "58b094cec72f74db91e131277804e82e168c16e0"
Unverified Commit b6516358 authored by ruizhang1230's avatar ruizhang1230 Committed by GitHub
Browse files

support hidden size 8192 (#264)

* support hidden size 8192

* refactor code

* fix assert
parent 486dd1d9
......@@ -499,9 +499,9 @@ combine(void* combined_x,
cg::this_grid().sync();
// Reduce tokens
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= num_threads);
EP_DEVICE_ASSERT(num_topk <= 32 and hidden_bf16_int4 <= 1024);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int k = thread_id; k < hidden_bf16_int4; k += num_threads) {
for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) {
// Read top-k indices and weights
int reg_topk_idx[kNumMaxTopk];
......@@ -520,7 +520,7 @@ combine(void* combined_x,
auto rdma_buffer_row = reinterpret_cast<const uint8_t*>(rdma_buffer_type);
// Reduce
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
auto x_vec = ld_nc_global(reinterpret_cast<const int4*>(rdma_buffer_row) + k);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
......@@ -533,7 +533,7 @@ combine(void* combined_x,
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++ j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
(static_cast<int4*>(combined_x) + token_idx * hidden_bf16_int4)[k] = combined_int4;
}
}
}
......
......@@ -85,5 +85,6 @@ cfg.dynamicSmemBytes = smem_size;
case 4096: case_macro(4096); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
case 8192: case_macro(8192); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
} while (false)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment