Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
06d2c7b1
Commit
06d2c7b1
authored
Jun 28, 2023
by
Jing Zhang
Committed by
root
Jun 28, 2023
Browse files
clean
parents
b27909a0
3b18f1e3
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
443 additions
and
25 deletions
+443
-25
client_example/15_gemm_add_multiply/gemm_add_multiply.cpp
client_example/15_gemm_add_multiply/gemm_add_multiply.cpp
+1
-1
client_example/15_reduce/reduce_nhwc_c.cpp
client_example/15_reduce/reduce_nhwc_c.cpp
+1
-1
client_example/16_convnd_fwd/common.hpp
client_example/16_convnd_fwd/common.hpp
+0
-4
client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp
client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp
+2
-2
client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp
client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp
+2
-2
client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp
...xample/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp
+1
-1
client_example/18_groupnorm/groupnorm_swish.cpp
client_example/18_groupnorm/groupnorm_swish.cpp
+29
-4
client_example/19_pool_fwd/CMakeLists.txt
client_example/19_pool_fwd/CMakeLists.txt
+5
-0
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
+199
-0
client_example/19_pool_fwd/max_pool2d_fwd.cpp
client_example/19_pool_fwd/max_pool2d_fwd.cpp
+193
-0
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+1
-1
example/01_gemm/gemm_dl_fp16.cpp
example/01_gemm/gemm_dl_fp16.cpp
+1
-1
example/01_gemm/gemm_dl_fp32.cpp
example/01_gemm/gemm_dl_fp32.cpp
+1
-1
example/01_gemm/gemm_dl_int4.cpp
example/01_gemm/gemm_dl_int4.cpp
+1
-1
example/01_gemm/gemm_dl_int8.cpp
example/01_gemm/gemm_dl_int8.cpp
+1
-1
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+1
-1
example/01_gemm/gemm_xdl_bf16.cpp
example/01_gemm/gemm_xdl_bf16.cpp
+1
-1
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+1
-1
example/01_gemm/gemm_xdl_fp64.cpp
example/01_gemm/gemm_xdl_fp64.cpp
+1
-1
example/01_gemm/gemm_xdl_int4.cpp
example/01_gemm/gemm_xdl_int4.cpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
client_example/15_gemm_add_multiply/gemm_add_multiply.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iomanip>
#include <vector>
#include <vector>
...
...
client_example/15_reduce/reduce_nhwc_c.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <functional>
#include <numeric>
#include <numeric>
...
...
client_example/16_convnd_fwd/common.hpp
View file @
06d2c7b1
...
@@ -141,14 +141,10 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
...
@@ -141,14 +141,10 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
std
::
next
(
rbegin
(
in_strides
)),
std
::
next
(
rbegin
(
in_strides
)),
std
::
next
(
rbegin
(
in_strides
),
NumDimSpatial
+
1
));
std
::
next
(
rbegin
(
in_strides
),
NumDimSpatial
+
1
));
std
::
rotate
(
std
::
next
(
rbegin
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_lengths
),
2
),
rend
(
wei_lengths
));
std
::
rotate
(
rbegin
(
wei_lengths
),
std
::
rotate
(
rbegin
(
wei_lengths
),
std
::
next
(
rbegin
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_lengths
),
NumDimSpatial
+
1
));
std
::
next
(
rbegin
(
wei_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
std
::
next
(
rbegin
(
wei_strides
)),
std
::
next
(
rbegin
(
wei_strides
),
2
),
rend
(
wei_strides
));
std
::
rotate
(
rbegin
(
wei_strides
),
std
::
rotate
(
rbegin
(
wei_strides
),
std
::
next
(
rbegin
(
wei_strides
)),
std
::
next
(
rbegin
(
wei_strides
)),
std
::
next
(
rbegin
(
wei_strides
),
NumDimSpatial
+
1
));
std
::
next
(
rbegin
(
wei_strides
),
NumDimSpatial
+
1
));
...
...
client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp
View file @
06d2c7b1
...
@@ -11,7 +11,7 @@ using WeiDataType = ck::half_t;
...
@@ -11,7 +11,7 @@ using WeiDataType = ck::half_t;
using
OutDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KZYX
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KZYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
...
@@ -38,7 +38,7 @@ int main()
...
@@ -38,7 +38,7 @@ int main()
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
OutLayout
>
(
OutLayout
>
(
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
K
,
Z
,
Y
,
X
,
G
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
G
,
K
,
Z
,
Y
,
X
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
?
EXIT_SUCCESS
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
:
EXIT_FAILURE
;
}
}
client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp
View file @
06d2c7b1
...
@@ -11,7 +11,7 @@ using WeiDataType = float;
...
@@ -11,7 +11,7 @@ using WeiDataType = float;
using
OutDataType
=
float
;
using
OutDataType
=
float
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KZYX
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KZYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
...
@@ -38,7 +38,7 @@ int main()
...
@@ -38,7 +38,7 @@ int main()
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
OutLayout
>
(
OutLayout
>
(
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
K
,
Z
,
Y
,
X
,
G
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
G
,
K
,
Z
,
Y
,
X
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
?
EXIT_SUCCESS
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
:
EXIT_FAILURE
;
}
}
client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iomanip>
#include <iostream>
#include <iostream>
...
...
client_example/18_groupnorm/groupnorm_swish.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iomanip>
#include <vector>
#include <vector>
...
@@ -72,6 +72,30 @@ int main(int argc, char* argv[])
...
@@ -72,6 +72,30 @@ int main(int argc, char* argv[])
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
const
auto
&
generic_op_ptr
=
op_ptrs
[
0
];
auto
generic_argument_ptr
=
generic_op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
G
,
C
},
// lengths
xy_strides
,
// xStrides
gamma_beta_strides
,
// gammaStrides
gamma_beta_strides
,
// betaStrides
xy_strides
,
// yStrides
{
1
,
2
,
4
},
// reduceDims
1e-6
,
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
Swish
{});
if
(
!
generic_op_ptr
->
IsSupportedArgument
(
generic_argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"The generic kernel instance should be able to support any input shapes"
);
};
std
::
string
best_op_name
;
std
::
string
best_op_name
;
bool
found
=
false
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
int
best_op_id
=
-
1
;
...
@@ -131,11 +155,12 @@ int main(int argc, char* argv[])
...
@@ -131,11 +155,12 @@ int main(int argc, char* argv[])
}
}
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
// run the best intance
if
(
found
)
{
{
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
<<
std
::
endl
;
...
...
client_example/19_pool_fwd/CMakeLists.txt
0 → 100644
View file @
06d2c7b1
add_executable
(
client_max_pool2d_fwd max_pool2d_fwd.cpp
)
target_link_libraries
(
client_max_pool2d_fwd PRIVATE composable_kernel::device_operations
)
add_executable
(
client_avg_pool3d_fwd avg_pool3d_fwd.cpp
)
target_link_libraries
(
client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations
)
\ No newline at end of file
client_example/19_pool_fwd/avg_pool3d_fwd.cpp
0 → 100644
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp"
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
IndexDataType
=
int32_t
;
constexpr
ck
::
index_t
InOutRank
=
5
;
constexpr
ck
::
index_t
WindowRank
=
3
;
#if 0
constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
constexpr bool OutputIndex = false;
#else
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
constexpr
bool
OutputIndex
=
false
;
#endif
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
ck
::
index_t
N
=
2
;
ck
::
index_t
C
=
32
;
ck
::
index_t
Z
=
2
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Di
=
30
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
window_stride_d
=
2
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_d
=
1
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_d
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
Do
=
(
Di
+
in_left_pad_d
+
in_right_pad_d
-
Z
)
/
window_stride_d
+
1
;
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
// Pool API only support the order of NCDHW
std
::
vector
<
ck
::
index_t
>
in_length
=
{
N
,
C
,
Di
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
out_length
=
{
N
,
C
,
Do
,
Ho
,
Wo
};
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
=
{
Z
,
Y
,
X
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
window_stride_d
,
window_stride_h
,
window_stride_w
};
std
::
vector
<
ck
::
index_t
>
input_left_pads
=
{
in_left_pad_d
,
in_left_pad_h
,
in_left_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_right_pads
=
{
in_right_pad_d
,
in_right_pad_h
,
in_right_pad_w
};
std
::
size_t
in_tensor_size
=
N
*
C
*
Di
*
Hi
*
Wi
;
std
::
size_t
out_tensor_size
=
N
*
C
*
Do
*
Ho
*
Wo
;
// tensor layout = NDHWC
std
::
vector
<
ck
::
index_t
>
in_tensor_stride
=
{
Di
*
C
*
Hi
*
Wi
,
1
,
C
*
Hi
*
Wi
,
Wi
*
C
,
C
};
std
::
vector
<
ck
::
index_t
>
out_tensor_stride
=
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
};
SimpleDeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_tensor_size
);
SimpleDeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_tensor_size
);
SimpleDeviceMem
out_indices_device_buf
(
sizeof
(
IndexDataType
)
*
out_tensor_size
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DevicePoolFwd
<
InOutRank
,
WindowRank
,
InDataType
,
OutDataType
,
IndexDataType
,
ReduceOpId
,
OutputIndex
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_length
,
window_spatial_lengths
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
input_left_pads
,
input_right_pads
,
{
2
,
3
,
4
});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_bytes
=
in_tensor_size
*
sizeof
(
InDataType
)
+
out_tensor_size
*
sizeof
(
OutDataType
);
if
constexpr
(
OutputIndex
)
num_bytes
+=
out_tensor_size
*
sizeof
(
IndexDataType
);
float
gb_per_sec
=
num_bytes
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
// run the best intance
if
(
found
)
{
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_length
,
window_spatial_lengths
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
input_left_pads
,
input_right_pads
,
{
2
,
3
,
4
});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
client_example/19_pool_fwd/max_pool2d_fwd.cpp
0 → 100644
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp"
using
InDataType
=
ck
::
half_t
;
using
OutDataType
=
ck
::
half_t
;
using
IndexDataType
=
int32_t
;
constexpr
ck
::
index_t
InOutRank
=
4
;
constexpr
ck
::
index_t
WindowRank
=
2
;
#if 1
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
constexpr
bool
OutputIndex
=
true
;
#else
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
constexpr
bool
OutputIndex
=
false
;
#endif
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
ck
::
index_t
N
=
2
;
ck
::
index_t
C
=
32
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Hi
=
30
;
ck
::
index_t
Wi
=
30
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
// Pool API only support the order of NCHW
std
::
vector
<
ck
::
index_t
>
in_length
=
{
N
,
C
,
Hi
,
Wi
};
std
::
vector
<
ck
::
index_t
>
out_length
=
{
N
,
C
,
Ho
,
Wo
};
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
=
{
Y
,
X
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
window_stride_h
,
window_stride_w
};
std
::
vector
<
ck
::
index_t
>
input_left_pads
=
{
in_left_pad_h
,
in_left_pad_w
};
std
::
vector
<
ck
::
index_t
>
input_right_pads
=
{
in_right_pad_h
,
in_right_pad_w
};
std
::
size_t
in_tensor_size
=
N
*
C
*
Hi
*
Wi
;
std
::
size_t
out_tensor_size
=
N
*
C
*
Ho
*
Wo
;
// tensor layout = NHWC
std
::
vector
<
ck
::
index_t
>
in_tensor_stride
=
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
};
std
::
vector
<
ck
::
index_t
>
out_tensor_stride
=
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
};
SimpleDeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_tensor_size
);
SimpleDeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_tensor_size
);
SimpleDeviceMem
out_indices_device_buf
(
sizeof
(
IndexDataType
)
*
out_tensor_size
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DevicePoolFwd
<
InOutRank
,
WindowRank
,
InDataType
,
OutDataType
,
IndexDataType
,
ReduceOpId
,
OutputIndex
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_length
,
window_spatial_lengths
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
input_left_pads
,
input_right_pads
,
{
2
,
3
});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_bytes
=
in_tensor_size
*
sizeof
(
InDataType
)
+
out_tensor_size
*
sizeof
(
OutDataType
);
if
constexpr
(
OutputIndex
)
num_bytes
+=
out_tensor_size
*
sizeof
(
IndexDataType
);
float
gb_per_sec
=
num_bytes
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
// run the best intance
if
(
found
)
{
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
out_indices_device_buf
.
GetDeviceBuffer
()),
in_length
,
window_spatial_lengths
,
out_length
,
in_tensor_stride
,
out_tensor_stride
,
out_tensor_stride
,
window_strides
,
input_left_pads
,
input_right_pads
,
{
2
,
3
});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
example/01_gemm/common.hpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
example/01_gemm/gemm_dl_fp16.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_dl_fp32.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_dl_int4.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#error Should compile this file with ck::int4_t support
...
...
example/01_gemm/gemm_dl_int8.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_wmma_fp16.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_xdl_bf16.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_xdl_fp16.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_xdl_fp64.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
example/01_gemm/gemm_xdl_int4.cpp
View file @
06d2c7b1
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#error Should compile this file with ck::int4_t support
...
...
Prev
1
2
3
4
5
6
7
…
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment