gaoqiong / composable_kernel / Commits

Commit 2baf0613
Authored Oct 20, 2023 by Jing Zhang
Parent: b164b0ef

clean code: add multiA into example

Showing 3 changed files with 87 additions and 69 deletions:

example/15_grouped_gemm/CMakeLists.txt  (+1 -3)
example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp  (+65 -26)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp  (+21 -40)
example/15_grouped_gemm/CMakeLists.txt  (+1 -3)

@@ -32,6 +32,4 @@ if(USE_BITINT_EXTENSION_INT4)
 endif()

 add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
-if(result EQUAL 0)
-    add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
-endif()
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp  (+65 -26)

@@ -33,8 +33,9 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Add         = ck::tensor_operation::element_wise::Add;

 using A0DataType       = F16;
+using A1DataType       = F32;
+using AsDataType       = ck::Tuple<A0DataType, A1DataType>;
 using B0DataType       = F16;
-using AsDataType       = ck::Tuple<A0DataType>;
 using BsDataType       = ck::Tuple<B0DataType>;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
@@ -43,14 +44,26 @@ using DsDataType = ck::Tuple<D0DataType>;
 using EDataType        = F32;

 using A0Layout = Row;
+using A1Layout = Row;
+using AsLayout = ck::Tuple<A0Layout, A1Layout>;
 using B0Layout = Col;
-using AsLayout = ck::Tuple<A0Layout>;
 using BsLayout = ck::Tuple<B0Layout>;
 using D0Layout = Row;
 using DsLayout = ck::Tuple<D0Layout>;
 using ELayout  = Row;

-using AElementOp   = PassThrough;
+struct AddScale
+{
+    __host__ __device__ constexpr void
+    operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const
+    {
+        a = scale * (a0 + a1);
+    }
+
+    float scale = 1.0;
+};
+
+using AElementOp   = AddScale;
 using BElementOp   = PassThrough;
 using CDEElementOp = Add;
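Note on the hunk above: AddScale is the new A-side elementwise op that fuses the two A inputs (the F16 A0 matrix and the F32 A1 bias) into one scaled value before the GEMM consumes it. Below is a minimal host-only sketch of the same semantics, using plain float in place of ck::half_t so it compiles without the CK headers; the names here are illustrative and not part of the CK API.

#include <cassert>

// Hypothetical stand-in for the AddScale functor above, with float instead of ck::half_t.
struct AddScaleSketch
{
    void operator()(float& a, const float& a0, const float& a1) const
    {
        a = scale * (a0 + a1); // same formula as the functor in the diff
    }

    float scale = 1.0f;
};

int main()
{
    AddScaleSketch op{2.0f};
    float a = 0.0f;
    op(a, 1.5f, 0.5f); // a = 2.0f * (1.5f + 0.5f) = 4.0f
    assert(a == 4.0f);
    return 0;
}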
@@ -113,13 +126,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         }
     };

-    std::vector<Tensor<A0DataType>> a_tensors;
+    std::vector<Tensor<A0DataType>> a0_tensors;
+    std::vector<Tensor<A1DataType>> a1_tensors;
     std::vector<Tensor<B0DataType>> b_tensors;
     std::vector<Tensor<D0DataType>> d0_tensors;
     std::vector<Tensor<EDataType>> c_host_tensors;
     std::vector<Tensor<EDataType>> c_device_tensors;

-    a_tensors.reserve(group_count);
+    a0_tensors.reserve(group_count);
+    a1_tensors.reserve(group_count);
     b_tensors.reserve(group_count);
     d0_tensors.reserve(group_count);
     c_host_tensors.reserve(group_count);
@@ -127,10 +142,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     using DeviceMemPtr = std::unique_ptr<DeviceMem>;

-    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
-        c_tensors_device;
+    std::vector<DeviceMemPtr> a0_tensors_device, a1_tensors_device, b_tensors_device,
+        d0_tensors_device, c_tensors_device;

-    a_tensors_device.reserve(group_count);
+    a0_tensors_device.reserve(group_count);
+    a1_tensors_device.reserve(group_count);
     b_tensors_device.reserve(group_count);
     d0_tensors_device.reserve(group_count);
     c_tensors_device.reserve(group_count);
@@ -140,8 +156,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     for(int i = 0; i < group_count; i++)
     {
         sum_of_m += problem_size.Ms[i];

-        a_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
+        a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
             problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
+        a1_tensors.push_back(Tensor<A1DataType>(f_host_tensor_descriptor(
+            problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{})));
         b_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
             problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
         d0_tensors.push_back(Tensor<D0DataType>(
@@ -150,12 +168,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
         c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
             problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));

-        std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
+        std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
                   << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
                   << " c_m_n: " << c_device_tensors[i].mDesc << std::endl;

         flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];

-        num_btype += sizeof(A0DataType) * a_tensors[i].mDesc.GetElementSize() +
+        num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
+                     sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() +
                      sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() +
                      sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
                      sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
@@ -164,15 +183,18 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         {
         case 0: break;
         case 1:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+            a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+            a1_tensors[i].GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
             break;
         case 2:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+            a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+            a1_tensors[i].GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
             break;
         default:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
         }
@@ -180,16 +202,19 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     }

     using GroupedGemmKernelArgument =
-        ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 1>;
+        ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<2, 1, 1>;

     std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
     grouped_gemm_kernel_args_.reserve(group_count);

     for(int i = 0; i < group_count; i++)
     {
-        a_tensors_device.emplace_back(
+        a0_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
+        a1_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i]));

         b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
             sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
@@ -199,9 +224,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         c_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));

-        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
-                                      a_tensors[i].mDesc.GetElementSpaceSize() *
+        a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
+                                       a0_tensors[i].mDesc.GetElementSpaceSize() *
                                           sizeof(A0DataType));
+
+        a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(),
+                                       a1_tensors[i].mDesc.GetElementSpaceSize() *
+                                           sizeof(A1DataType));

         b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
                                       b_tensors[i].mDesc.GetElementSpaceSize() *
                                           sizeof(B0DataType));
@@ -211,20 +240,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             gemm_descs.push_back({sum_of_m,
                                   problem_size.Ns[i],
                                   problem_size.Ks[i],
-                                  {1},
+                                  {1, 1},
                                   {problem_size.stride_Bs[i]},
                                   {0},
                                   1});

             grouped_gemm_kernel_args_.push_back(
-                {std::array<const void*, 1>{a_tensors_device[i]->GetDeviceBuffer()},
+                {std::array<const void*, 2>{a0_tensors_device[i]->GetDeviceBuffer(),
+                                            a1_tensors_device[i]->GetDeviceBuffer()},
                  std::array<const void*, 1>{b_tensors_device[i]->GetDeviceBuffer()},
                  std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
                  c_tensors_device[i]->GetDeviceBuffer(),
                  problem_size.Ms[i],
                  problem_size.Ns[i],
                  problem_size.Ks[i],
-                 std::array<ck::index_t, 1>{problem_size.stride_As[i]},
+                 std::array<ck::index_t, 2>{problem_size.stride_As[i], problem_size.stride_As[i]},
                  std::array<ck::index_t, 1>{problem_size.stride_Bs[i]},
                  std::array<ck::index_t, 1>{0},
                  problem_size.stride_Cs[i]});
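The change above is where the second A tensor actually reaches the kernel: the per-group argument now uses GroupedGemmMultiABDKernelArgument<2, 1, 1> (two A tensors, one B, one D), so each group carries a two-element pointer array and a two-element stride array on the A side. The following is a hypothetical, simplified sketch of that pattern only; the real ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument has its own field names and layout, and none of the identifiers below are part of the CK API.

#include <array>

// Sketch: the NumA/NumB/NumD template parameters size the pointer and stride arrays
// that every group supplies to the grouped GEMM kernel.
template <int NumA, int NumB, int NumD>
struct MultiABDKernelArgSketch
{
    std::array<const void*, NumA> as_ptr; // one pointer per A tensor (A0, A1, ...)
    std::array<const void*, NumB> bs_ptr;
    std::array<const void*, NumD> ds_ptr;
    void* e_ptr;

    int M, N, K;
    std::array<int, NumA> stride_As;
    std::array<int, NumB> stride_Bs;
    std::array<int, NumD> stride_Ds;
    int stride_E;
};

// With the A1 bias added, the example moves from <1, 1, 1> to <2, 1, 1>,
// so each group now provides two A pointers and two A strides, as in the diff above.
using ArgSketch = MultiABDKernelArgSketch<2, 1, 1>;

int main()
{
    ArgSketch arg{};
    (void)arg;
    return 0;
}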
@@ -237,7 +267,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

-    std::vector<std::array<const void*, 1>> p_As = {};
+    std::vector<std::array<const void*, 2>> p_As = {};
     std::vector<std::array<const void*, 1>> p_Bs = {};
     std::vector<std::array<const void*, 1>> p_Ds = {};
     std::vector<void*> p_Cs = {};
@@ -281,16 +311,25 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     bool pass = true;
     if(config.do_verification)
     {
         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
                                                                                 B0DataType,
                                                                                 EDataType,
                                                                                 AccDataType,
-                                                                                AElementOp,
+                                                                                PassThrough,
                                                                                 BElementOp,
                                                                                 PassThrough>;

         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
+            for(int m = 0; m < problem_size.Ms[i]; ++m)
+            {
+                for(int k = 0; k < problem_size.Ks[i]; ++k)
+                {
+                    a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k));
+                }
+            }
+
             c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
                                             c_device_tensors[i].mDesc.GetElementSize() *
                                                 sizeof(EDataType));
@@ -298,10 +337,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             auto ref_gemm    = ReferenceGemmInstance{};
             auto ref_invoker = ref_gemm.MakeInvoker();

-            auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
+            auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
                                                       b_tensors[i],
                                                       c_host_tensors[i],
-                                                      a_element_op,
+                                                      PassThrough{},
                                                       b_element_op,
                                                       PassThrough{});
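The two verification hunks above go together: the reference GEMM only sees a single A operand, so the example first applies a_element_op (AddScale) in place on the host, folding A1 into A0, and then calls ReferenceGemm with PassThrough on the A side. A minimal illustrative sketch of that pre-combination step (plain float, hypothetical names, not CK code):

#include <cstddef>
#include <vector>

// Sketch: fold the A1 bias into A0 with the same AddScale math, so a single-A
// reference GEMM with a PassThrough A-elementwise op can be reused for verification.
void fold_a1_into_a0(std::vector<float>& a0, const std::vector<float>& a1, float scale = 1.0f)
{
    for(std::size_t i = 0; i < a0.size(); ++i)
    {
        a0[i] = scale * (a0[i] + a1[i]);
    }
}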
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp  (+21 -40)

@@ -93,14 +93,6 @@ __global__ void
     typename GridwiseGemm::BsGridPointer p_bs_grid_;
     typename GridwiseGemm::DsGridPointer p_ds_grid_;

-    // constexpr auto I0 = Number<0>{};
-
-    // using AsDataType = remove_cvref_t<decltype(p_as_grid_(I0))>;
-    // p_as_grid_(I0) = static_cast<AsDataType>(gemm_desc_ptr[group_id].p_a_grid);
-
-    // using BsDataType = remove_cvref_t<decltype(p_bs_grid_(I0))>;
-    // p_bs_grid_(I0) = static_cast<BsDataType>(gemm_desc_ptr[group_id].p_b_grid);
-
     static_for<0, NumATensor, 1>{}([&](auto i) {
         using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
         p_as_grid_(i)   = static_cast<ADataType>(gemm_desc_ptr[group_id].p_as_grid[i]);
@@ -500,35 +492,32 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
             const index_t StrideE = gemm_descs[i].stride_C_;

-            static_for<0, NumATensor, 1>{}([&](auto j) {
-                if(gemm_descs[i].stride_As_.size() != NumATensor)
-                {
-                    throw std::runtime_error(
-                        "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
-                }
-
-                StrideAs[j] = gemm_descs[i].stride_As_[j];
-            });
+            if(gemm_descs[i].stride_As_.size() != NumATensor)
+            {
+                throw std::runtime_error(
+                    "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
+            }
+
+            static_for<0, NumATensor, 1>{}(
+                [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; });

-            static_for<0, NumBTensor, 1>{}([&](auto j) {
-                if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
-                {
-                    throw std::runtime_error(
-                        "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
-                }
-
-                StrideBs[j] = gemm_descs[i].stride_Bs_[j];
-            });
+            if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
+            {
+                throw std::runtime_error(
+                    "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
+            }
+
+            static_for<0, NumBTensor, 1>{}(
+                [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; });

-            static_for<0, NumDTensor, 1>{}([&](auto j) {
-                if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
-                {
-                    throw std::runtime_error(
-                        "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
-                }
-
-                StrideDs[j] = gemm_descs[i].stride_Ds_[j];
-            });
+            if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
+            {
+                throw std::runtime_error(
+                    "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
+            }
+
+            static_for<0, NumDTensor, 1>{}(
+                [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; });

             const auto e_grid_desc_m_n =
                 GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
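The restructuring above hoists the stride-array size check out of the compile-time static_for loop, so the runtime check happens once per gemm descriptor instead of once per unrolled tensor index, and the loop body shrinks to the per-index copy. The snippet below is a self-contained sketch of that pattern; static_for_sketch is a hand-rolled stand-in for ck::static_for (not its real implementation), and the function and variable names are illustrative only.

#include <array>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hand-rolled stand-in for ck::static_for: calls f(integral_constant<size_t, I>) for I = 0..N-1.
template <typename F, std::size_t... Is>
void static_for_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
void static_for_sketch(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}

// Pattern from the diff: validate the runtime container size once, then let the
// compile-time loop do nothing but the per-index copy.
template <std::size_t NumATensor>
std::array<int, NumATensor> copy_strides(const std::vector<int>& stride_As_)
{
    if(stride_As_.size() != NumATensor)
    {
        throw std::runtime_error("wrong! stride_As_.size() does not match NumATensor");
    }

    std::array<int, NumATensor> StrideAs{};
    static_for_sketch<NumATensor>([&](auto j) { StrideAs[j] = stride_As_[j]; });
    return StrideAs;
}

int main()
{
    auto s = copy_strides<2>({128, 128}); // two A strides, matching NumATensor = 2
    return (s[0] == 128 && s[1] == 128) ? 0 : 1;
}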
@@ -552,14 +541,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
                 throw std::runtime_error("wrong! block_2_etile_map validation failed");
             }

-            // if(!GridwiseGemm::
-            //     template CheckValidity<AsLayout, BsLayout, DsLayout, ELayout, GemmSpec>(
-            //         AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
-            //{
-            //    throw std::runtime_error(
-            //        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
-            //}
-
             gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
                 p_as_grid,
                 p_bs_grid,