Commit 4086a4cc authored by wenjh
Browse files

Merge branch 'TE_ta_develop2.9' into 'develop_v2.9'

[DCU] Fix compilation failure: unable to find nvte_extract_seed_and_offset

See merge request dcutoolkit/deeplearing/TransformerEngine!64
parents 177291ac 121d9224
......@@ -432,7 +432,7 @@ def test_fuser_ops_with_userbuffers(
command = []
if tex.ubuf_built_with_mpi():
python_exe = pathlib.Path(sys.executable).resolve()
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe))
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--allow-run-as-root", python_exe))
else:
command.extend(("torchrun", f"--nproc_per_node={world_size}"))
......
......@@ -219,6 +219,7 @@ else()
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/utils.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
......
......@@ -44,7 +44,17 @@ __device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) {
// Bitwise logical op to get answer in MSBs
// Equivalent logic: result = (a == b) ? !result : b
#ifdef __HIP_PLATFORM_AMD__
result = (a == b) ? !result : b;
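// Compare the four byte lanes one at a time through a union, setting one flag bit per lane.
// NOTE(review): bytes[0] is the least-significant byte on little-endian targets but sets bit 31
// here — confirm the lane-to-bit order matches what the CUDA lop3 path produces.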
union { uint32_t u32; uint8_t bytes[4]; } a_union, b_union;
a_union.u32 = a;
b_union.u32 = b;
uint32_t mask = 0;
mask |= (a_union.bytes[0] < b_union.bytes[0]) ? 0x80000000U : 0;
mask |= (a_union.bytes[1] < b_union.bytes[1]) ? 0x00800000U : 0;
mask |= (a_union.bytes[2] < b_union.bytes[2]) ? 0x00008000U : 0;
mask |= (a_union.bytes[3] < b_union.bytes[3]) ? 0x00000080U : 0;
result = mask;
#else
asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result));
#endif
......
......@@ -14,7 +14,7 @@
namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
// get matrix strides based on matrix type
......@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
return hout;
}
#endif
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
......
......@@ -7,10 +7,11 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <cstdint>
#include <mutex>
......@@ -19,7 +20,7 @@
namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
enum NVTE_QKV_Matrix {
......@@ -186,7 +187,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
#endif
} // namespace fused_attn
} // namespace transformer_engine
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment