Unverified Commit 6eef0755 authored by Haocong WANG's avatar Haocong WANG Committed by GitHub
Browse files

fix wmma gemm int8; add grouped conv int8 example (#716)

parent 1344a0f2
......@@ -20,4 +20,5 @@ if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS M
endif()
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
// kernel data types
using InKernelDataType = I8;
using WeiKernelDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I8;
using BiasKernelDataType = I8;
using ResidualKernelDataType = I8;
using OutKernelDataType = I8;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
......@@ -262,12 +262,12 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
template <index_t MPerWmma,
index_t NPerWmma,
bool neg_a,
bool neg_b,
bool clamp,
class FloatA,
class FloatB,
class FloatC>
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment