Commit 08204531 authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] fix build on windows

parent b1fec976
......@@ -1449,17 +1449,19 @@ public:
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
if constexpr (rank > 0) {
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int m = 0; m < LORA_M_TILES; m++) {
for (int n = 0; n < LORA_N_TILES; n++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
}
}
}
......@@ -1498,42 +1500,41 @@ public:
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
if constexpr (rank > 0) {
lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
// clock_t dummy = 0;
// clock_t dummy = 0;
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
#pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum");
CHECK_NAN(psum, "apply_lora_down.psum");
}
}
}
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
// unused_var(dummy, alwaysfalse);
// unused_var(dummy, alwaysfalse);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment