Commit 3eee1b9b authored by Harisankar Sadasivan
Browse files

adding tall and skinny gemm

parent 67adf1b4
......@@ -7,7 +7,7 @@
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemv.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
......@@ -19,12 +19,12 @@ namespace instance {
void add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemv<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemv<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
template <typename ADataType,
......@@ -34,7 +34,7 @@ template <typename ADataType,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemv<ALayout,
ck::tensor_operation::device::DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
......@@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemv<ALayout,
using DeviceOp = DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment