Commit ba058648 authored by tabuchixiangcai3
Browse files

[DCU] Fix compilation failure: unable to find nvte_extract_seed_and_offset


Signed-off-by: Tangao <2205747538@qq.com>
parent 177291ac
......@@ -219,6 +219,7 @@ else()
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/utils.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
......
......@@ -14,7 +14,7 @@
namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
// get matrix strides based on matrix type
......@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
return hout;
}
#endif
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
......
......@@ -7,10 +7,11 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <cstdint>
#include <mutex>
......@@ -19,7 +20,7 @@
namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
enum NVTE_QKV_Matrix {
......@@ -186,7 +187,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
#endif
} // namespace fused_attn
} // namespace transformer_engine
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment