Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
869a738c
Commit
869a738c
authored
Aug 25, 2023
by
aska-0096
Browse files
fix a bug
parent
a97298a4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
6 deletions
+39
-6
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
+12
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+27
-6
No files found.
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
View file @
869a738c
...
@@ -220,6 +220,18 @@ int main(int argc, char* argv[])
...
@@ -220,6 +220,18 @@ int main(int argc, char* argv[])
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
DDataType
>
{
-
0.5
,
0.5
});
d_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
DDataType
>
{
-
0.5
,
0.5
});
}
}
#if 0
for(int im = 0; im<M; im++)
{
for(int ik = 0; ik<K; ik++)
{
if(ik%8==0) printf("|");
printf("%4x ", *(reinterpret_cast<uint16_t*>(&(a_m_k(im,ik)))));
}
printf("\n");
}
#endif
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpaceSize
());
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
869a738c
...
@@ -550,6 +550,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -550,6 +550,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr
auto
ordered_src_access_stride
=
constexpr
auto
ordered_src_access_stride
=
container_reorder_given_new2old
(
src_access_stride
,
src_dim_access_order
);
container_reorder_given_new2old
(
src_access_stride
,
src_dim_access_order
);
constexpr
auto
ordered_src_access_unit
=
container_reorder_given_new2old
(
src_access_unit
,
src_dim_access_order
);
// judge move forward or move backward during the last iteration
// judge move forward or move backward during the last iteration
constexpr
auto
forward_sweep
=
[
&
]()
{
constexpr
auto
forward_sweep
=
[
&
]()
{
...
@@ -558,10 +560,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -558,10 +560,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
forward_sweep_
(
I0
)
=
true
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_
stride
[
I0
]
-
1
;
index_t
tmp
=
ordered_src_access_
unit
[
I0
]
-
1
;
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_src_access_
stride
[
j
]
+
ordered_src_access_
stride
[
j
]
-
1
;
tmp
=
tmp
*
ordered_src_access_
unit
[
j
]
+
ordered_src_access_
unit
[
j
]
-
1
;
});
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
...
@@ -619,11 +621,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -619,11 +621,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
1
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
1
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_strides
=
ThreadClusterLengths
{}
*
(
dst_access_unit
-
dst_access_unit_helper
);
constexpr
auto
dst_access_strides
=
ThreadClusterLengths
{}
*
(
dst_access_unit
-
dst_access_unit_helper
);
#if 0
if (get_thread_local_1d_id()==0)
{
printf("dst_access_strides: %d, %d, %d\n",
dst_access_strides.At(Number<0>{}).value,
dst_access_strides.At(Number<1>{}).value,
dst_access_strides.At(Number<2>{}).value);
}
#endif
constexpr
auto
dst_dim_access_order
=
DstDimAccessOrder
{};
constexpr
auto
dst_dim_access_order
=
DstDimAccessOrder
{};
constexpr
auto
ordered_dst_access_strides
=
constexpr
auto
ordered_dst_access_strides
=
container_reorder_given_new2old
(
dst_access_strides
,
dst_dim_access_order
);
container_reorder_given_new2old
(
dst_access_strides
,
dst_dim_access_order
);
constexpr
auto
ordered_dst_access_unit
=
container_reorder_given_new2old
(
dst_access_unit
,
dst_dim_access_order
);
// judge move forward or move backward during the last iteration
// judge move forward or move backward during the last iteration
constexpr
auto
forward_sweep
=
[
&
]()
{
constexpr
auto
forward_sweep
=
[
&
]()
{
...
@@ -632,10 +645,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -632,10 +645,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
forward_sweep_
(
I0
)
=
true
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_
strides
[
I0
]
-
1
;
index_t
tmp
=
ordered_dst_access_
unit
[
I0
]
-
1
;
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_
strides
[
j
]
+
ordered_dst_access_
strides
[
j
]
-
1
;
tmp
=
tmp
*
ordered_dst_access_
unit
[
j
]
+
ordered_dst_access_
unit
[
j
]
-
1
;
});
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
...
@@ -643,7 +656,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -643,7 +656,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
return
forward_sweep_
;
return
forward_sweep_
;
}();
}();
#if 0
if (get_thread_local_1d_id()==0)
{
printf("forward_sweep: %d, %d, %d\n",
forward_sweep[Number<0>{}],
forward_sweep[Number<1>{}],
forward_sweep[Number<2>{}]);
}
#endif
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
// RunWrite()
constexpr
auto
dst_data_idx
=
[
&
]()
{
constexpr
auto
dst_data_idx
=
[
&
]()
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment