// flash_bwd_sumout.cpp
// Last committed by zhangshao.
#include <hip/hip_runtime.h>
#include "flash_sumout_api.h"

// Explicit instantiation definitions for the flash_sum_out kernel template.
//
// Each statement below makes this translation unit emit code for one element
// type. Other translation units can then refer to these specializations
// without seeing the template body. Because these are instantiation
// *definitions*, the template's definition must be visible at this point.
// Presumably it is provided by flash_sumout_api.h (NOTE(review): confirm).
//
// Every instantiation has the shape (T*, T*, int, int). What the two pointers
// and two ints mean is defined by the template itself, not by this file.

// Float16 element type.
template __global__ void flash_sum_out<Float16>(Float16*, Float16*, int, int);

// BFloat16 element type.
template __global__ void flash_sum_out<BFloat16>(BFloat16*, BFloat16*, int, int);