Commit 3eee1b9b authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

adding tall and skinny gemm

parent 67adf1b4
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <memory> #include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemv.hpp" #include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...@@ -19,12 +19,12 @@ namespace instance { ...@@ -19,12 +19,12 @@ namespace instance {
void add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances( void add_device_gemv_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemv<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceTsmm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances( void add_device_gemv_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemv<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
template <typename ADataType, template <typename ADataType,
...@@ -34,7 +34,7 @@ template <typename ADataType, ...@@ -34,7 +34,7 @@ template <typename ADataType,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
struct DeviceOperationInstanceFactory< struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemv<ALayout, ck::tensor_operation::device::DeviceTsmm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
...@@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory< ...@@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory<
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>> ck::tensor_operation::element_wise::PassThrough>>
{ {
using DeviceOp = DeviceGemv<ALayout, using DeviceOp = DeviceTsmm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, ADataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment