Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
987eff1b
Commit
987eff1b
authored
Sep 21, 2022
by
Jing Zhang
Browse files
fixed GetDsPtrOffset/GetEPtrOffset with long_index
parent
5d7ab929
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
16 deletions
+67
-16
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
+61
-11
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...ce/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+6
-5
No files found.
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
View file @
987eff1b
...
...
@@ -228,7 +228,7 @@ int main(int argc, char* argv[])
// E[G, N0, M0, N1, M1, N2]
std
::
vector
<
ck
::
index_t
>
e_gs_ms_ns_lengths
{
G
,
M0
,
M1
,
N0
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
e_gs_ms_ns_strides
{
M
0
*
M
1
*
N
0
*
N
1
*
N2
,
N1
*
M1
*
N2
,
N2
,
M0
*
N1
*
M1
*
N2
,
M1
*
N2
,
1
};
N
0
*
M
0
*
N
1
*
M
1
*
N2
,
N1
*
M1
*
N2
,
N2
,
M0
*
N1
*
M1
*
N2
,
M1
*
N2
,
1
};
if
(
argc
==
1
)
{
...
...
@@ -257,9 +257,6 @@ int main(int argc, char* argv[])
Tensor
<
DDataType
>
d_gs_ms_ns
(
std
::
vector
<
std
::
size_t
>
(
d_gs_ms_ns_lengths
.
begin
(),
d_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
d_gs_ms_ns_strides
.
begin
(),
d_gs_ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_gs_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_gs_ms_ns_lengths
.
begin
(),
e_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_gs_ms_ns_strides
.
begin
(),
e_gs_ms_ns_strides
.
end
()));
Tensor
<
EDataType
>
e_gs_ms_ns_device_result
(
std
::
vector
<
std
::
size_t
>
(
e_gs_ms_ns_lengths
.
begin
(),
e_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_gs_ms_ns_strides
.
begin
(),
e_gs_ms_ns_strides
.
end
()));
...
...
@@ -267,7 +264,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"a_gs_ms_ks: "
<<
a_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_gs_ns_ks: "
<<
b_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d_gs_ms_ns: "
<<
d_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_gs_ms_ns: "
<<
e_gs_ms_ns_
host
_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_gs_ms_ns: "
<<
e_gs_ms_ns_
device
_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -359,9 +356,26 @@ int main(int argc, char* argv[])
if
(
do_verification
)
{
Tensor
<
CShuffleDataType
>
c_gs_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
e_gs_ms_ns_lengths
.
begin
(),
e_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
e_gs_ms_ns_strides
.
begin
(),
e_gs_ms_ns_strides
.
end
()));
const
ck
::
index_t
G_
=
1
;
const
ck
::
index_t
N0_
=
3
;
// A[G, M0, M1, K0]
std
::
vector
<
ck
::
index_t
>
host_a_gs_ms_ks_lengths
{
G_
,
M0
,
M1
,
K0
};
std
::
vector
<
ck
::
index_t
>
host_a_gs_ms_ks_strides
{
M0
*
M1
*
K0
,
M1
*
K0
,
K0
,
1
};
// B[G, N0_, N1, N2, K0]
std
::
vector
<
ck
::
index_t
>
host_b_gs_ns_ks_lengths
{
G_
,
N0_
,
N1
,
N2
,
K0
};
std
::
vector
<
ck
::
index_t
>
host_b_gs_ns_ks_strides
{
N0_
*
N1
*
N2
*
K0
,
N1
*
N2
*
K0
,
N2
*
K0
,
K0
,
1
};
// D[G_, N0_, M0, N1, M1, N2]
std
::
vector
<
ck
::
index_t
>
host_d_gs_ms_ns_lengths
{
G_
,
M0
,
M1
,
N0_
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
host_d_gs_ms_ns_strides
{
N0_
*
N1
*
N2
,
0
,
0
,
N1
*
N2
,
N2
,
1
};
// E[G_, N0_, M0, N1, M1, N2]
std
::
vector
<
ck
::
index_t
>
host_e_gs_ms_ns_lengths
{
G_
,
M0
,
M1
,
N0_
,
N1
,
N2
};
std
::
vector
<
ck
::
index_t
>
host_e_gs_ms_ns_strides
{
N0_
*
M0
*
N1
*
M1
*
N2
,
N1
*
M1
*
N2
,
N2
,
M0
*
N1
*
M1
*
N2
,
M1
*
N2
,
1
};
using
ReferenceOpInstance
=
ReferenceContraction_G1_M2_N3_K1
<
NumDimM
,
NumDimN
,
...
...
@@ -377,8 +391,44 @@ int main(int argc, char* argv[])
auto
ref_gemm
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_gs_ms_ks
,
b_gs_ns_ks
,
Tensor
<
ADataType
>
host_a_gs_ms_ks
(
std
::
vector
<
std
::
size_t
>
(
host_a_gs_ms_ks_lengths
.
begin
(),
host_a_gs_ms_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
host_a_gs_ms_ks_strides
.
begin
(),
host_a_gs_ms_ks_strides
.
end
()));
Tensor
<
BDataType
>
host_b_gs_ns_ks
(
std
::
vector
<
std
::
size_t
>
(
host_b_gs_ns_ks_lengths
.
begin
(),
host_b_gs_ns_ks_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
host_b_gs_ns_ks_strides
.
begin
(),
host_b_gs_ns_ks_strides
.
end
()));
Tensor
<
DDataType
>
host_d_gs_ms_ns
(
std
::
vector
<
std
::
size_t
>
(
host_d_gs_ms_ns_lengths
.
begin
(),
host_d_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
host_d_gs_ms_ns_strides
.
begin
(),
host_d_gs_ms_ns_strides
.
end
()));
std
::
copy
(
a_gs_ms_ks
.
mData
.
begin
(),
a_gs_ms_ks
.
mData
.
end
(),
host_a_gs_ms_ks
.
begin
());
std
::
copy
(
b_gs_ns_ks
.
mData
.
begin
(),
b_gs_ns_ks
.
mData
.
end
(),
host_b_gs_ns_ks
.
begin
());
std
::
copy
(
d_gs_ms_ns
.
mData
.
begin
(),
d_gs_ms_ns
.
mData
.
end
(),
host_d_gs_ms_ns
.
begin
());
Tensor
<
EDataType
>
e_gs_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
host_e_gs_ms_ns_lengths
.
begin
(),
host_e_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
host_e_gs_ms_ns_strides
.
begin
(),
host_e_gs_ms_ns_strides
.
end
()));
std
::
cout
<<
"host_a_gs_ms_ks: "
<<
host_a_gs_ms_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"host_b_gs_ns_ks: "
<<
host_b_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"host_d_gs_ms_ns: "
<<
host_d_gs_ms_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"host_e_gs_ms_ns: "
<<
e_gs_ms_ns_host_result
.
mDesc
<<
std
::
endl
;
Tensor
<
CShuffleDataType
>
c_gs_ms_ns_host_result
(
std
::
vector
<
std
::
size_t
>
(
host_e_gs_ms_ns_lengths
.
begin
(),
host_e_gs_ms_ns_lengths
.
end
()),
std
::
vector
<
std
::
size_t
>
(
host_e_gs_ms_ns_strides
.
begin
(),
host_e_gs_ms_ns_strides
.
end
()));
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
host_a_gs_ms_ks
,
host_b_gs_ns_ks
,
c_gs_ms_ns_host_result
,
a_element_op
,
b_element_op
,
...
...
@@ -401,7 +451,7 @@ int main(int argc, char* argv[])
{
cde_element_op
(
e_gs_ms_ns_host_result
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
),
c_gs_ms_ns_host_result
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
),
d_gs_ms_ns
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
));
host_
d_gs_ms_ns
(
g0
,
m0
,
m1
,
n0
,
n1
,
n2
));
}
}
}
...
...
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
987eff1b
...
...
@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
batch_stride_A_
)
;
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_A_
;
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
batch_stride_B_
)
;
return
static_cast
<
long_index_t
>
(
g_idx
)
*
batch_stride_B_
;
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
...
...
@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std
::
array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
ds_offset
[
i
]
=
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
ds_offset
[
i
]
=
static_cast
<
long_index_t
>
(
g_idx
)
*
ds_grid_desc_g_m_n_
[
i
].
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
});
return
ds_offset
;
...
...
@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
return
static_cast
<
long_index_t
>
(
g_idx
)
*
e_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
1
,
0
,
0
));
}
private:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment