Commit 08204531 authored by sxtyzhangzk, committed by Zhekai Zhang
Browse files

[major] fix build on windows

parent b1fec976
...@@ -1449,17 +1449,19 @@ public: ...@@ -1449,17 +1449,19 @@ public:
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales); if constexpr (rank > 0) {
lora_wgt_warp lora_wgt = load_lora_wgt(wgt); lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
for (int m = 0; m < LORA_M_TILES; m++) { lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
for (int n = 0; n < LORA_N_TILES; n++) { for (int m = 0; m < LORA_M_TILES; m++) {
packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]); for (int n = 0; n < LORA_N_TILES; n++) {
for (int r = 0; r < LORA_R_TILES; r++) { packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act"); for (int r = 0; r < LORA_R_TILES; r++) {
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt"); CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum); CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
}
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
} }
fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
} }
} }
} }
...@@ -1498,42 +1500,41 @@ public: ...@@ -1498,42 +1500,41 @@ public:
const int laneId = threadIdx.x % WARP_SIZE; const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE; const int warpId = threadIdx.x / WARP_SIZE;
lora_act_warp lora_act; if constexpr (rank > 0) {
lora_act.fill(packed_f32psum_t::zeros()); lora_act_warp lora_act;
lora_act.fill(packed_f32psum_t::zeros());
lora_wgt_warp lora_wgt = load_lora_wgt(wgt); lora_wgt_warp lora_wgt = load_lora_wgt(wgt);
// clock_t dummy = 0; // clock_t dummy = 0;
#pragma unroll #pragma unroll
for (int m = 0; m < LORA_M_TILES; m++) { for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll #pragma unroll
for (int n = 0; n < LORA_N_TILES; n++) { for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll #pragma unroll
for (int r = 0; r < LORA_R_TILES; r++) { for (int r = 0; r < LORA_R_TILES; r++) {
auto &psum = lora_act[m * LORA_R_TILES + r]; auto &psum = lora_act[m * LORA_R_TILES + r];
CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum"); CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt"); CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum); psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
CHECK_NAN(psum, "apply_lora_down.psum"); CHECK_NAN(psum, "apply_lora_down.psum");
}
} }
} // reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// if (alwaysfalse) { // if (alwaysfalse) {
// dummy = clock(); // dummy = clock();
// } // }
} }
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);
// unused_var(dummy, alwaysfalse); // unused_var(dummy, alwaysfalse);
}
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment