Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
08204531
Commit
08204531
authored
Nov 10, 2024
by
sxtyzhangzk
Committed by
Zhekai Zhang
Nov 10, 2024
Browse files
[major] fix build on windows
parent
b1fec976
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
36 deletions
+37
-36
src/kernels/gemm_w4a4.cu
src/kernels/gemm_w4a4.cu
+37
-36
No files found.
src/kernels/gemm_w4a4.cu
View file @
08204531
...
...
@@ -1449,17 +1449,19 @@ public:
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
lora_act16_warp
lora_act
=
load_lora_act
(
act
+
warpId
*
(
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
scales
);
lora_wgt_warp
lora_wgt
=
load_lora_wgt
(
wgt
);
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
packed_f32psum_t
psum
=
packed_fp16_to_fp32
(
fpsum
[
m
*
WARP_N_TILES
+
n
]);
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
CHECK_NAN
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
"lora_act"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"lora_wgt"
);
psum
=
mma_f16xf16_f32
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
psum
);
if
constexpr
(
rank
>
0
)
{
lora_act16_warp
lora_act
=
load_lora_act
(
act
+
warpId
*
(
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
scales
);
lora_wgt_warp
lora_wgt
=
load_lora_wgt
(
wgt
);
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
packed_f32psum_t
psum
=
packed_fp16_to_fp32
(
fpsum
[
m
*
WARP_N_TILES
+
n
]);
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
CHECK_NAN
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
"lora_act"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"lora_wgt"
);
psum
=
mma_f16xf16_f32
(
lora_act
[
m
*
LORA_R_TILES
+
r
],
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
psum
);
}
fpsum
[
m
*
WARP_N_TILES
+
n
]
=
packed_fp32_to_fp16
(
psum
);
}
fpsum
[
m
*
WARP_N_TILES
+
n
]
=
packed_fp32_to_fp16
(
psum
);
}
}
}
...
...
@@ -1498,42 +1500,41 @@ public:
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
lora_act_warp
lora_act
;
lora_act
.
fill
(
packed_f32psum_t
::
zeros
());
if
constexpr
(
rank
>
0
)
{
lora_act_warp
lora_act
;
lora_act
.
fill
(
packed_f32psum_t
::
zeros
());
lora_wgt_warp
lora_wgt
=
load_lora_wgt
(
wgt
);
lora_wgt_warp
lora_wgt
=
load_lora_wgt
(
wgt
);
// clock_t dummy = 0;
// clock_t dummy = 0;
#pragma unroll
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
#pragma unroll
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
auto
&
psum
=
lora_act
[
m
*
LORA_R_TILES
+
r
];
#pragma unroll
for
(
int
m
=
0
;
m
<
LORA_M_TILES
;
m
++
)
{
#pragma unroll
for
(
int
n
=
0
;
n
<
LORA_N_TILES
;
n
++
)
{
#pragma unroll
for
(
int
r
=
0
;
r
<
LORA_R_TILES
;
r
++
)
{
auto
&
psum
=
lora_act
[
m
*
LORA_R_TILES
+
r
];
CHECK_NAN
(
fpsum
[
m
*
WARP_N_TILES
+
n
],
"apply_lora_down.fpsum"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"apply_lora_down.lora_wgt"
);
CHECK_NAN
(
fpsum
[
m
*
WARP_N_TILES
+
n
],
"apply_lora_down.fpsum"
);
CHECK_NAN
(
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
"apply_lora_down.lora_wgt"
);
psum
=
mma_f16xf16_f32
(
fpsum
[
m
*
WARP_N_TILES
+
n
],
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
psum
);
psum
=
mma_f16xf16_f32
(
fpsum
[
m
*
WARP_N_TILES
+
n
],
lora_wgt
[
n
*
LORA_R_TILES
+
r
],
psum
);
CHECK_NAN
(
psum
,
"apply_lora_down.psum"
);
CHECK_NAN
(
psum
,
"apply_lora_down.psum"
);
}
}
}
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act
(
act
+
warpId
*
(
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
lora_act
);
// if (alwaysfalse) {
// dummy = clock();
// }
}
reduce_lora_act
(
act
+
warpId
*
(
LORA_M_TILES
*
LORA_R_TILES
*
8
*
WARP_SIZE
),
lora_act
);
// unused_var(dummy, alwaysfalse);
// unused_var(dummy, alwaysfalse);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment