Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ed4662b5
"vscode:/vscode.git/clone" did not exist on "e9eb0938f4f401bb697fcdd2dfafccf9e149d647"
Commit
ed4662b5
authored
Aug 25, 2022
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/gemm_int4_examples
parents
cb8f87cb
e1a3fff6
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
371 additions
and
261 deletions
+371
-261
.gitignore
.gitignore
+1
-0
client_example/05_layernorm/CMakeLists.txt
client_example/05_layernorm/CMakeLists.txt
+2
-0
client_example/05_layernorm/layernorm2d.cpp
client_example/05_layernorm/layernorm2d.cpp
+159
-0
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+1
-0
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+0
-3
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp
...peration/gpu/device/device_gemm_multiple_d_multiple_r.hpp
+14
-2
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
...device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+43
-234
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
...grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+62
-20
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+1
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+2
-0
library/include/ck/library/tensor_operation_instance/gpu/layernorm.hpp
...de/ck/library/tensor_operation_instance/gpu/layernorm.hpp
+85
-0
No files found.
.gitignore
View file @
ed4662b5
...
@@ -46,3 +46,4 @@ build*
...
@@ -46,3 +46,4 @@ build*
# GDB temporary files
# GDB temporary files
.gdb_history
.gdb_history
install.dir*
client_example/05_layernorm/CMakeLists.txt
0 → 100644
View file @
ed4662b5
add_executable
(
client_layernorm2d layernorm2d.cpp
)
target_link_libraries
(
client_layernorm2d PRIVATE composable_kernel::device_operations
)
client_example/05_layernorm/layernorm2d.cpp
0 → 100644
View file @
ed4662b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm.hpp"
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
2
;
constexpr
int
NumReduceDim
=
1
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
ck
::
index_t
M
=
1024
;
ck
::
index_t
N
=
1024
;
ck
::
index_t
Stride
=
1024
;
auto
xy_size
=
(
M
-
1
)
*
Stride
+
N
;
SimpleDeviceMem
x_device_buf
(
sizeof
(
XDataType
)
*
xy_size
);
SimpleDeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
N
);
SimpleDeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
N
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
YDataType
)
*
xy_size
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceLayernorm
<
XDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
YDataType
,
PassThrough
,
Rank
,
NumReduceDim
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
M
,
N
},
// lengths
{
Stride
,
1
},
// xStrides
{
1
},
// gammaStrides
{
1
},
// betaStrides
{
Stride
,
1
},
// yStrides
{
1
},
// reduceDims
1e-4
,
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
M
*
N
+
sizeof
(
GammaDataType
)
*
N
+
sizeof
(
BetaDataType
)
*
N
+
sizeof
(
YDataType
)
*
M
*
N
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
M
,
N
},
// lengths
{
Stride
,
1
},
// xStrides
{
1
},
// gammaStrides
{
1
},
// betaStrides
{
Stride
,
1
},
// yStrides
{
1
},
// reduceDims
1e-4
,
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
client_example/CMakeLists.txt
View file @
ed4662b5
...
@@ -10,3 +10,4 @@ add_subdirectory(01_gemm)
...
@@ -10,3 +10,4 @@ add_subdirectory(01_gemm)
add_subdirectory
(
02_gemm_add_add_fastgelu
)
add_subdirectory
(
02_gemm_add_add_fastgelu
)
add_subdirectory
(
03_gemm_layernorm
)
add_subdirectory
(
03_gemm_layernorm
)
add_subdirectory
(
04_contraction
)
add_subdirectory
(
04_contraction
)
add_subdirectory
(
05_layernorm
)
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
ed4662b5
...
@@ -3,9 +3,6 @@
...
@@ -3,9 +3,6 @@
#pragma once
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
ed4662b5
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include <array>
#include <array>
#include "device_base.hpp"
#include "
ck/tensor_operation/gpu/device/
device_base.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp
View file @
ed4662b5
...
@@ -3,15 +3,27 @@
...
@@ -3,15 +3,27 @@
#pragma once
#pragma once
#include <
iostream
>
#include <
array
>
#include "device_base.hpp"
#include "
ck/tensor_operation/gpu/device/
device_base.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
// FIXME: DeviceGemmReduce type need to well define the problem
// FIXME: DeviceGemmReduce type need to well define the problem
// GEMM:
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// output : R0[M], R1[M], ...
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
// R0 = r_op0(Q0), R1 = r_op1(Q1), ...
// Assume:
// D0, D1, ... and E have the same layout
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
DELayout
,
typename
DELayout
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
ed4662b5
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
@@ -192,7 +193,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -192,7 +193,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
auto
MakeAGridDescriptor_M_K
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
@@ -207,95 +211,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -207,95 +211,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
return
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
assert
(
K
%
AK1
==
0
);
const
auto
AK0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
assert
(
KRaw
%
AK1
==
0
);
const
auto
AK0
=
KRaw
/
AK1
;
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
}
static
auto
MakeBGridDescriptor_
BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_
N_K
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
@@ -310,92 +229,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -310,92 +229,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}
}();
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
return
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
K
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
)
*
KPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
KPad
=
K
-
KRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
NRaw
,
NPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
assert
(
K
%
BK1
==
0
);
const
auto
BK0
=
K
/
BK1
;
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
NRaw
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
assert
(
KRaw
%
BK1
==
0
);
const
auto
BK0
=
KRaw
/
BK1
;
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
}
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
static
auto
MakeEGridDescriptor_M_N
(
index_t
MRaw
,
index_t
NRaw
,
index_t
StrideE
)
...
@@ -413,47 +247,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -413,47 +247,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}
}();
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
NPad
=
N
-
NRaw
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
e_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
MRaw
),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
e_grid_desc_mraw_nraw
;
}
}
}
// assume D is packed tensor
// assume D is packed tensor
...
@@ -482,10 +276,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -482,10 +276,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
}
}
}
}
using
AGridDesc_
AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_
AK0_M_AK1
(
1
,
1
,
1
));
using
AGridDesc_
M_K
=
decltype
(
MakeAGridDescriptor_
M_K
(
1
,
1
,
1
));
using
BGridDesc_
BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_
BK0_N_BK1
(
1
,
1
,
1
));
using
BGridDesc_
N_K
=
decltype
(
MakeBGridDescriptor_
N_K
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeEGridDescriptor_M_N
(
1
,
1
,
1
));
using
RGridDesc_M
=
decltype
(
MakeRGridDescriptor_M
(
1
));
using
RGridDesc_M
=
decltype
(
MakeRGridDescriptor_M
(
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
using
GridwiseGemm
=
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
...
@@ -504,8 +298,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -504,8 +298,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
ThreadReduceOperations
,
ThreadReduceOperations
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
RsGlobalMemoryDataOperation
,
RsGlobalMemoryDataOperation
,
AGridDesc_
AK0_M_AK1
,
AGridDesc_
M_K
,
BGridDesc_
BK0_N_BK1
,
BGridDesc_
N_K
,
EGridDesc_M_N
,
EGridDesc_M_N
,
RGridDesc_M
,
RGridDesc_M
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
...
@@ -542,6 +336,13 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -542,6 +336,13 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
RThreadTransferDstScalarPerVector_MPerBlock
,
RThreadTransferDstScalarPerVector_MPerBlock
,
LoopSched
>
;
LoopSched
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -567,12 +368,16 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -567,12 +368,16 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_
{},
// FIXME
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_rs_grid_
{},
// FIXME
p_rs_grid_
{},
// FIXME
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideE
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideE
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
r_grid_desc_m_
{
DeviceOp
::
MakeRGridDescriptor_M
(
MRaw
)},
r_grid_desc_m_
{
DeviceOp
::
MakeRGridDescriptor_M
(
MRaw
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
rs_grid_desc_mblock_mperblock_
{},
rs_grid_desc_mblock_mperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
...
@@ -581,8 +386,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -581,8 +386,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
qs_element_op_
{
qs_element_op
},
qs_element_op_
{
qs_element_op
},
rs_element_op_
{
rs_element_op
}
rs_element_op_
{
rs_element_op
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_
ak0_m_ak1
_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_
m_k
_
,
b_grid_desc_
bk0_n_bk1
_
,
b_grid_desc_
n_k
_
,
e_grid_desc_m_n_
,
e_grid_desc_m_n_
,
r_grid_desc_m_
,
r_grid_desc_m_
,
block_2_etile_map_
))
block_2_etile_map_
))
...
@@ -624,6 +429,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -624,6 +429,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
typename
GridwiseGemm
::
RsGridPointer
p_rs_grid_
;
typename
GridwiseGemm
::
RsGridPointer
p_rs_grid_
;
// tensor descriptors
// tensor descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
RGridDesc_M
r_grid_desc_m_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
StaticallyIndexedArray
<
...
@@ -631,16 +442,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -631,16 +442,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
NumDTensor
>
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
RGridDesc_M
r_grid_desc_m_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
RGridDescriptor_MBlock_MPerBlock
,
NumRTensor
>
StaticallyIndexedArray
<
typename
GridwiseGemm
::
RGridDescriptor_MBlock_MPerBlock
,
NumRTensor
>
rs_grid_desc_mblock_mperblock_
;
rs_grid_desc_mblock_mperblock_
;
// block-to-e-tile map
// block-to-e-tile map
typename
GridwiseGemm
::
Default
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
// element-wise op
// element-wise op
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
...
@@ -657,8 +466,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -657,8 +466,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
ak0_m_ak1
_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
m_k
_
,
arg
.
b_grid_desc_
bk0_n_bk1
_
,
arg
.
b_grid_desc_
n_k
_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
r_grid_desc_m_
,
arg
.
r_grid_desc_m_
,
arg
.
block_2_etile_map_
))
arg
.
block_2_etile_map_
))
...
@@ -750,8 +559,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
...
@@ -750,8 +559,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
return
false
;
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
ak0_m_ak1
_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
m_k
_
,
arg
.
b_grid_desc_
bk0_n_bk1
_
,
arg
.
b_grid_desc_
n_k
_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
r_grid_desc_m_
,
arg
.
r_grid_desc_m_
,
arg
.
block_2_etile_map_
);
arg
.
block_2_etile_map_
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
ed4662b5
...
@@ -32,8 +32,8 @@ template <typename FloatAB,
...
@@ -32,8 +32,8 @@ template <typename FloatAB,
typename
ThreadReduceOperations
,
typename
ThreadReduceOperations
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
typename
RsGlobalMemoryDataOperation
,
typename
RsGlobalMemoryDataOperation
,
typename
AGridDesc_
AK0_M_AK1
,
typename
AGridDesc_
M_K
,
typename
BGridDesc_
BK0_N_BK1
,
typename
BGridDesc_
N_K
,
typename
EGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
RGridDesc_M
,
typename
RGridDesc_M
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -84,10 +84,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
AK
0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
AK
1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK
0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
BK
1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK
1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
AK
0PerBlock
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK
1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BK
0PerBlock
=
Number
<
KPerBlock
/
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -97,7 +97,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
{
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
AK0
PerBlock
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
}
...
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -105,7 +105,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
{
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
BK0
PerBlock
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
}
...
@@ -167,22 +167,57 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -167,22 +167,57 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_size
*
sizeof
(
FloatCShuffle
));
c_block_size
*
sizeof
(
FloatCShuffle
));
}
}
// A desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AK0
=
K
/
AK1
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// B desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
MakeDefaultBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
const
auto
BK0
=
K
/
BK1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2ETileMap
>
template
<
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
RGridDesc_M
&
r_grid_desc_m
,
const
RGridDesc_M
&
r_grid_desc_m
,
const
Block2ETileMap
&
block_2_etile_map
)
const
Block2ETileMap
&
block_2_etile_map
)
{
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
static_assert
(
AGridDesc_M_K
::
GetNumOfDimension
()
==
2
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
static_assert
(
BGridDesc_N_K
::
GetNumOfDimension
()
==
2
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
static_assert
(
EGridDesc_M_N
::
GetNumOfDimension
()
==
2
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
return
false
;
...
@@ -259,6 +294,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -259,6 +294,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
e_grid_desc_m_n
);
e_grid_desc_m_n
);
}
}
using
DefaultAGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
DefaultBGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
...
@@ -272,7 +311,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -272,7 +311,10 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
DsGridPointer
=
decltype
(
MakeTsGridPointer
<
DsDataType
,
true
>
());
using
DsGridPointer
=
decltype
(
MakeTsGridPointer
<
DsDataType
,
true
>
());
using
RsGridPointer
=
decltype
(
MakeTsGridPointer
<
RsDataType
,
false
>
());
using
RsGridPointer
=
decltype
(
MakeTsGridPointer
<
RsDataType
,
false
>
());
template
<
bool
HasMainKBlockLoop
,
typename
Block2ETileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
__device__
static
void
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
@@ -356,7 +398,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -356,7 +398,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
...
@@ -387,7 +429,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -387,7 +429,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
...
...
include/ck/utility/reduction_operator.hpp
View file @
ed4662b5
...
@@ -79,7 +79,7 @@ struct SquaredAdd
...
@@ -79,7 +79,7 @@ struct SquaredAdd
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the
Max
accumulator!"
);
"The data type is not supported by the
SquaredAdd
accumulator!"
);
a
=
a
+
b
*
b
;
a
=
a
+
b
*
b
;
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
ed4662b5
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#pragma once
#pragma once
#include <cstdlib>
#include <cstdlib>
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...
...
library/include/ck/library/tensor_operation_instance/gpu/layernorm.hpp
0 → 100644
View file @
ed4662b5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_layernorm_f16_rank2_instances
(
std
::
vector
<
DeviceLayernormPtr
<
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
2
,
1
>>&
);
void
add_device_layernorm_f16_rank4_instances
(
std
::
vector
<
DeviceLayernormPtr
<
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
4
,
3
>>&
);
void
add_device_layernorm_f32_rank2_instances
(
std
::
vector
<
DeviceLayernormPtr
<
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
2
,
1
>>&
);
void
add_device_layernorm_f32_rank4_instances
(
std
::
vector
<
DeviceLayernormPtr
<
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
4
,
3
>>&
);
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceLayernorm
<
XDataType
,
GammaDataType
,
BetaDataType
,
F32
,
YDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Rank
,
NumReduceDim
>>
{
using
DeviceOp
=
DeviceLayernorm
<
XDataType
,
GammaDataType
,
BetaDataType
,
F32
,
YDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Rank
,
NumReduceDim
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
GammaDataType
,
F16
>
&&
is_same_v
<
BetaDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
)
{
if
constexpr
(
Rank
==
2
&&
NumReduceDim
==
1
)
add_device_layernorm_f16_rank2_instances
(
op_ptrs
);
else
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
)
add_device_layernorm_f16_rank4_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
GammaDataType
,
F32
>
&&
is_same_v
<
BetaDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
2
&&
NumReduceDim
==
1
)
add_device_layernorm_f32_rank2_instances
(
op_ptrs
);
else
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
)
add_device_layernorm_f32_rank4_instances
(
op_ptrs
);
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment