Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
16e3f66a
Commit
16e3f66a
authored
Jun 08, 2022
by
ltqin
Browse files
fix bug
parent
263589eb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
12 deletions
+16
-12
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
+5
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
+11
-8
No files found.
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp
View file @
16e3f66a
...
@@ -54,7 +54,8 @@ using BDataType = ck::half_t;
...
@@ -54,7 +54,8 @@ using BDataType = ck::half_t;
using CDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using AccDataType = float;
#else
#else
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
16
,
64
,
4
,
4
,
16
,
16
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
true
,
4
,
7
,
1
>
;
<
F32
,
F32
,
F32
,
F32
,
Row
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
64
,
16
,
16
,
4
,
1
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
1
,
1
,
true
,
1
,
7
,
1
>
;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
using
ADataType
=
float
;
using
ADataType
=
float
;
...
@@ -87,10 +88,10 @@ template <typename DataType>
...
@@ -87,10 +88,10 @@ template <typename DataType>
std
::
ostream
&
show_2d_matrix
(
std
::
ostream
&
os
,
Tensor
<
DataType
>&
matrix
)
std
::
ostream
&
show_2d_matrix
(
std
::
ostream
&
os
,
Tensor
<
DataType
>&
matrix
)
{
{
os
<<
"["
<<
std
::
endl
;
os
<<
"["
<<
std
::
endl
;
for
(
in
t
x
=
0
;
x
<
matrix
.
mDesc
.
GetLengths
()[
0
];
x
++
)
for
(
size_
t
x
=
0
;
x
<
matrix
.
mDesc
.
GetLengths
()[
0
];
x
++
)
{
{
os
<<
"["
;
os
<<
"["
;
for
(
in
t
y
=
0
;
y
<
matrix
.
mDesc
.
GetLengths
()[
1
];
y
++
)
for
(
size_
t
y
=
0
;
y
<
matrix
.
mDesc
.
GetLengths
()[
1
];
y
++
)
{
{
os
<<
std
::
setw
(
5
)
<<
static_cast
<
float
>
(
matrix
(
x
,
y
));
os
<<
std
::
setw
(
5
)
<<
static_cast
<
float
>
(
matrix
(
x
,
y
));
}
}
...
@@ -117,7 +118,7 @@ int main(int argc, char* argv[])
...
@@ -117,7 +118,7 @@ int main(int argc, char* argv[])
#else
#else
ck
::
index_t
M
=
16
;
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
8
;
ck
::
index_t
K
=
32
;
ck
::
index_t
StrideA
=
8
;
ck
::
index_t
StrideA
=
8
;
ck
::
index_t
StrideB
=
8
;
ck
::
index_t
StrideB
=
8
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
View file @
16e3f66a
...
@@ -113,7 +113,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -113,7 +113,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
BaseMultK0
=
4
;
static
constexpr
auto
BaseMultK0
=
4
;
static
constexpr
auto
MultiK0
=
4
*
2
;
static
constexpr
auto
MultiK0
=
BaseMultK0
*
2
;
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
...
@@ -192,7 +192,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -192,7 +192,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// 2-stage prefetch currently only support even number of K0 loop
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
// TODO: add support for odd number of K0 loop
if
(
!
((
K0
/
K0PerBlock
)
%
2
==
0
))
if
(
!
((
K0
/
K0PerBlock
)
%
MultiK0
==
0
))
{
{
return
false
;
return
false
;
}
}
...
@@ -573,6 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -573,6 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_1st_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_1st_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
// 2nd
// 2nd
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
...
@@ -583,8 +584,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -583,8 +584,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_2nd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_2nd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
// 3rd
// 3rd
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
...
@@ -595,8 +596,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -595,8 +596,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_3rd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_3rd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
// 4th
// 4th
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
...
@@ -607,8 +608,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -607,8 +608,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_4th_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_4th_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
});
});
block_sync_lds
();
block_sync_lds
();
...
@@ -637,6 +638,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -637,6 +638,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_1st_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_1st_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
// 2nd
// 2nd
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
...
@@ -647,8 +649,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -647,8 +649,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_2nd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_2nd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
// 3rd
// 3rd
if
constexpr
(
i
<
MultiK0
-
BaseMultK0
)
if
constexpr
(
i
<
MultiK0
-
BaseMultK0
)
...
@@ -662,8 +664,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -662,8 +664,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_threadwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
}
}
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_3rd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_3rd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
// 4th
// 4th
if
constexpr
(
i
<
MultiK0
-
BaseMultK0
)
if
constexpr
(
i
<
MultiK0
-
BaseMultK0
)
...
@@ -677,8 +680,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
...
@@ -677,8 +680,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
}
}
blockwise_gemm
.
MoveABlockSliceWindow
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_4th_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_4th_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveABlockSliceWindow
();
});
});
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment