Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
22161866
Commit
22161866
authored
Feb 22, 2023
by
ltqin
Browse files
change d0 desc
parent
bef0cb20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
9 deletions
+9
-9
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
...s_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
+9
-9
No files found.
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
View file @
22161866
...
@@ -150,10 +150,10 @@ int main(int argc, char* argv[])
...
@@ -150,10 +150,10 @@ int main(int argc, char* argv[])
int
init_method
=
1
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
int
G0
=
3
;
int
G0
=
64
;
int
G1
=
2
;
int
G1
=
12
;
int
M
=
1024
;
int
M
=
512
;
int
N
=
1024
;
int
N
=
512
;
int
K
=
64
;
int
K
=
64
;
int
O
=
64
;
int
O
=
64
;
float
alpha
=
1
;
float
alpha
=
1
;
...
@@ -194,12 +194,11 @@ int main(int argc, char* argv[])
...
@@ -194,12 +194,11 @@ int main(int argc, char* argv[])
}
}
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
{
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
{
M
*
G1
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, M, G1, K]
M
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// A layout [G0, M, G1, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
{
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B0 layout [G0, N, G1, K]
N
*
G1
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, N, G1, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
{
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
{
...
@@ -211,7 +210,7 @@ int main(int argc, char* argv[])
...
@@ -211,7 +210,7 @@ int main(int argc, char* argv[])
// D layout [G0, M, G1, N]
// D layout [G0, M, G1, N]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
{
G1
*
N
,
N
,
0
,
1
};
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
...
@@ -224,6 +223,7 @@ int main(int argc, char* argv[])
...
@@ -224,6 +223,7 @@ int main(int argc, char* argv[])
std
::
cout
<<
"b0_gs_ns_ks: "
<<
b0_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_gs_ns_ks: "
<<
b0_gs_ns_ks
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_gs_os_ns: "
<<
b1_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b1_gs_os_ns: "
<<
b1_gs_os_ns
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_gs_ms_os: "
<<
c_gs_ms_os_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_gs_ms_os: "
<<
c_gs_ms_os_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_gs_ms_ns: "
<<
d0_gs_ms_ns
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -255,7 +255,7 @@ int main(int argc, char* argv[])
...
@@ -255,7 +255,7 @@ int main(int argc, char* argv[])
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
G0
*
G1
*
M
*
K
);
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
G0
*
G1
*
M
*
K
);
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
G0
*
G1
*
N
*
K
);
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
G0
*
G1
*
N
*
K
);
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
G0
*
G1
*
M
*
N
);
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
G0
*
G1
*
N
);
DeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
G0
*
G1
*
O
*
N
);
DeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
G0
*
G1
*
O
*
N
);
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
G0
*
G1
*
M
*
O
);
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
G0
*
G1
*
M
*
O
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment