Commit 85762c1a authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

Init the main branch for aiter

parent ae0b3521
Pipeline #3505 canceled with stages
import torch
import os
from typing import Optional, List
import functools
from bisect import bisect_left
import aiter
from aiter import ActivationType, QuantType, dtypes
from aiter.jit.core import AITER_ROOT_DIR
from aiter import ck_moe, ck_shuffle_moe
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.jit.utils.chip_info import get_gfx
from aiter.fused_moe import moe_sorting
from aiter import per_token_quant_hip, per_block_quant_wrapper
BLOCK_SIZE_M = 32
class MoeQuantType:
    """String identifiers for the MoE quantization modes known to this module."""

    NO_QUANT = "no_quant"
    INT4_W4A16 = "int4_w4a16"
    INT4_W4A8 = "int4_w4a8"
    INT8_W8A8 = "int8_w8a8_block"
    INT8_W8A8_C = "int8_w8a8_channel"

    # Every identifier accepted by is_valid().
    ALL_TYPES = [NO_QUANT, INT4_W4A16, INT4_W4A8, INT8_W8A8, INT8_W8A8_C]

    @classmethod
    def get_default(cls) -> str:
        """Return the identifier used when no quantization is requested."""
        return cls.NO_QUANT

    @classmethod
    def is_valid(cls, qtype_str: str) -> bool:
        """Return True when ``qtype_str`` is one of the identifiers in ALL_TYPES."""
        known_types = cls.ALL_TYPES
        return qtype_str in known_types
# CSV tuning tables, located under <AITER_ROOT_DIR>/aiter/configs/ck_tune/.
ck_tuned_file = os.path.join(AITER_ROOT_DIR, "aiter", "configs", "ck_tune", "tuned_fmoe_ck.csv")
ck_tuned_int8_w8a8_group_file = os.path.join(AITER_ROOT_DIR, "aiter", "configs", "ck_tune", "tuned_fmoe_ck_int8_w8a8_group.csv")
# Lazily filled caches used by the lookup functions below (None = not loaded yet).
# moe_ck_cfg holds the table that is currently selected.
moe_ck_cfg = None
# Raw tables, one per supported quantization mode.
moe_ck_noquant_cfg = None
moe_ck_int8_w8a8_group_cfg = None
# Pure-Python indexes built from the raw tables by build_moe_index().
moe_ck_noquant_index = None
moe_ck_int8_w8a8_group_index = None
# Quantization mode that moe_ck_cfg was last selected for.
current_quant_type = None
def get_moe_ck_solution(
    indtype,
    token,
    inter_dim,
    model_dim,
    expert,
    topk,
    quant_type,
    q_size_n=0,
    q_size_k=0
):
    """Return ``ck_moe`` bound to the tuned ``solution_id`` for a problem shape.

    The tuning table at ``ck_tuned_file`` is read once and cached in the
    module-level ``moe_ck_cfg``. Rows are filtered on every argument except
    ``token``; among the remaining rows the exact ``token`` is preferred and
    otherwise the row with the smallest absolute token distance is used.

    Args:
        indtype: Input dtype; compared via ``str(indtype)`` with the
            ``indtype`` column.
        token: Number of tokens of the current call.
        inter_dim, model_dim, expert, topk: Shape parameters matched exactly.
        quant_type: Compared via ``str(quant_type)`` with ``quant_type``.
        q_size_n, q_size_k: Quantization block sizes matched exactly.

    Returns:
        ``functools.partial(ck_moe, solution_id=...)``; ``solution_id`` is 0
        when the table cannot be read or holds no matching row.
    """
    def get_moe_cfg(ck_tuned_file):
        import pandas as pd
        try:
            moe_cfg = pd.read_csv(ck_tuned_file)
        except Exception as e:
            # Best effort: a missing or broken table falls back to solution 0.
            print(f">>> Warning: Failed to read config file {ck_tuned_file}: {e}")
            return None
        return moe_cfg

    global moe_ck_cfg
    if moe_ck_cfg is None:
        moe_ck_cfg = get_moe_cfg(ck_tuned_file)
    if moe_ck_cfg is None:
        print(f">>> Warning: config file {ck_tuned_file} is not found, using default ck solution.")
        return functools.partial(ck_moe, solution_id=0)

    mask = (
        (moe_ck_cfg["indtype"] == str(indtype)) &
        (moe_ck_cfg["inter_dim"] == inter_dim) &
        (moe_ck_cfg["model_dim"] == model_dim) &
        (moe_ck_cfg["expert"] == expert) &
        (moe_ck_cfg["topk"] == topk) &
        (moe_ck_cfg["quant_type"] == str(quant_type)) &
        (moe_ck_cfg["q_size_n"] == q_size_n) &
        (moe_ck_cfg["q_size_k"] == q_size_k)
    )
    matching_configs = moe_ck_cfg[mask]
    if matching_configs.empty:
        print(">>> Warning: No matching config pattern found, using default ck solution.")
        return functools.partial(ck_moe, solution_id=0)

    # 1. Exact token match.
    exact_match = matching_configs[matching_configs["token"] == token]
    if not exact_match.empty:
        sol_id = int(exact_match.iloc[0]["sol_id"])
        print(f">>> Info: Exact token match found for token={token}, using sol_id={sol_id}.")
        return functools.partial(ck_moe, solution_id=sol_id)

    # 2. Closest token. The distance is kept in its own Series instead of
    # being written into the filtered frame, which would assign to a slice of
    # the cached table (pandas SettingWithCopyWarning).
    token_distance = (matching_configs["token"] - token).abs()
    closest_label = token_distance.idxmin()
    closest_token = matching_configs.loc[closest_label, "token"]
    distance = token_distance.loc[closest_label]
    sol_id = int(matching_configs.loc[closest_label, "sol_id"])
    print(f">>> Info: Closest token match found: token={closest_token} (distance={distance}) for target token={token}, using sol_id={sol_id}.")
    return functools.partial(ck_moe, solution_id=sol_id)
def build_moe_index(df):
    """Convert the tuning table into a pure-Python lookup structure.

    Args:
        df: Table exposing ``itertuples(index=False)`` whose rows carry the
            attributes ``arch``, ``inter_dim``, ``model_dim``, ``expert``,
            ``topk``, ``quant_type``, ``q_size_n``, ``q_size_k``, ``token``
            and ``sol_id``.

    Returns:
        dict mapping ``(arch, inter_dim, model_dim, expert, topk, quant_type,
        q_size_n, q_size_k)`` to ``{"token_to_sol": {token: sol_id},
        "tokens": [sorted unique tokens]}``. When the same key/token occurs in
        several rows the last row wins.
    """
    moe_index = {}
    for row in df.itertuples(index=False):
        key = (
            row.arch,
            int(row.inter_dim),
            int(row.model_dim),
            int(row.expert),
            int(row.topk),
            str(row.quant_type),
            int(row.q_size_n),
            int(row.q_size_k),
        )
        entry = moe_index.setdefault(key, {"token_to_sol": {}, "tokens": []})
        entry["token_to_sol"][int(row.token)] = int(row.sol_id)
    for entry in moe_index.values():
        # Derive the token list from the dict keys so repeated rows for the
        # same token do not leave duplicates in the sorted list.
        entry["tokens"] = sorted(entry["token_to_sol"])
    return moe_index
def _find_closest_token(sorted_tokens, target_token):
    """Return the element of the ascending ``sorted_tokens`` nearest to ``target_token``.

    On a tie the smaller neighbour is returned. ``sorted_tokens`` must not be
    empty.
    """
    pos = bisect_left(sorted_tokens, target_token)
    # Target is not larger than any element: the first one is the closest.
    if pos == 0:
        return sorted_tokens[0]
    # Target is larger than every element: the last one is the closest.
    if pos == len(sorted_tokens):
        return sorted_tokens[-1]
    lower, upper = sorted_tokens[pos - 1], sorted_tokens[pos]
    gap_below = target_token - lower
    gap_above = upper - target_token
    return lower if gap_below <= gap_above else upper
def get_moe_ck_solution_id(
    arch,
    quant_type,
    token,
    inter_dim,
    model_dim,
    expert,
    topk,
    q_size_n=0,
    q_size_k=0
):
    """Look up the tuned CK solution id for a MoE problem shape.

    The CSV table for ``quant_type`` is loaded and indexed with
    ``build_moe_index`` on first use and cached in module globals. Any
    ``quant_type`` other than ``MoeQuantType.INT8_W8A8`` and
    ``MoeQuantType.NO_QUANT`` falls back to the no-quant table.

    Returns:
        int: ``sol_id`` of the row with the exact ``token`` if present,
        otherwise of the row with the closest token; 0 when the table, the
        index or a matching key is unavailable.
    """
    global moe_ck_cfg, current_quant_type
    global moe_ck_noquant_cfg, moe_ck_int8_w8a8_group_cfg
    global moe_ck_noquant_index, moe_ck_int8_w8a8_group_index

    def get_moe_cfg(ck_tuned_file):
        import pandas as pd
        try:
            return pd.read_csv(ck_tuned_file)
        except Exception as e:
            print(f">>> Warning: Failed to read config file {ck_tuned_file}: {e}")
            return None

    needs_switch = moe_ck_cfg is None or quant_type != current_quant_type
    if needs_switch:
        if quant_type == MoeQuantType.INT8_W8A8:
            if moe_ck_int8_w8a8_group_cfg is None:
                moe_ck_int8_w8a8_group_cfg = get_moe_cfg(ck_tuned_int8_w8a8_group_file)
                if moe_ck_int8_w8a8_group_cfg is not None:
                    moe_ck_int8_w8a8_group_index = build_moe_index(moe_ck_int8_w8a8_group_cfg)
            moe_ck_cfg = moe_ck_int8_w8a8_group_cfg
        else:
            unsupported = quant_type != MoeQuantType.NO_QUANT
            if unsupported:
                print(f">>> Warning: quant_type {quant_type} not supported for CK lookup, fallback to no-quant table.")
            if moe_ck_noquant_cfg is None:
                moe_ck_noquant_cfg = get_moe_cfg(ck_tuned_file)
                if moe_ck_noquant_cfg is not None:
                    moe_ck_noquant_index = build_moe_index(moe_ck_noquant_cfg)
            moe_ck_cfg = moe_ck_noquant_cfg
            if unsupported:
                # From here on the request is treated as a no-quant lookup.
                quant_type = MoeQuantType.NO_QUANT
        current_quant_type = quant_type

    current_index = (
        moe_ck_int8_w8a8_group_index
        if quant_type == MoeQuantType.INT8_W8A8
        else moe_ck_noquant_index
    )

    if moe_ck_cfg is None:
        print(">>> Warning: config file is not found, using default ck solution.")
        return 0
    if current_index is None:
        print(">>> Warning: ck index is not built, using default ck solution.")
        return 0

    key = (arch, inter_dim, model_dim, expert, topk, str(quant_type), q_size_n, q_size_k)
    candidates = current_index.get(key)
    if not candidates:
        print(f">>> Warning: No matching config pattern found for key={key}, using default ck solution.")
        return 0

    token = int(token)
    token_to_sol = candidates["token_to_sol"]
    # 1. Exact token match.
    exact_sol = token_to_sol.get(token)
    if exact_sol is not None:
        return int(exact_sol)
    # 2. Closest tuned token.
    nearest = _find_closest_token(candidates["tokens"], token)
    return int(token_to_sol[nearest])
def ck_moe_stage_1(
    hidden_states,
    w1,  # [E, inter_dim*2, model_dim]
    w2,  # [E, model_dim, inter_dim]
    sorted_token_ids,  # [max_num_tokens_padded]
    sorted_expert_ids,  # [max_num_m_blocks]
    tokens_positions_per_expert,  # [num_experts*2]
    num_valid_ids,  # [1]
    use_int8_w8a8_block: bool,
    use_fp8_w8a8_block: bool,
    w1_scale,
    a1_scale,
    dtype,
    topk,
    block_shape_n=0,
    block_shape_k=0,
    block_size=16,
    Activation=ActivationType.Silu,
    sorted_weights=None,  # [max_num_tokens_padded]
):
    """Run the first stage of the two-stage CK MoE and return its output.

    Allocates a float16 buffer of shape
    ``(hidden_states.shape[0] * topk, D // 2)``, where ``D`` is ``w1.shape[1]``
    (multiplied by 8 when ``w1`` is ``torch.uint32``), lets
    ``aiter.ck_moe_stage_1`` fill it and returns it converted to ``dtype``.

    ``Activation`` is passed to the kernel as ``act_op``: 1 for
    ``ActivationType.Silu``, 0 for anything else.
    """
    num_tokens = hidden_states.shape[0]
    gate_up_dim = w1.shape[1]
    if w1.dtype is torch.uint32:
        # Presumably eight packed int4 values per uint32 word -- TODO confirm.
        gate_up_dim = gate_up_dim * 8
    act_op = 1 if Activation == ActivationType.Silu else 0

    # The buffer is half of the gate/up width; NOTE(review): this suggests the
    # activation-and-multiply happens inside the kernel -- confirm.
    out = torch.empty(
        (num_tokens * topk, gate_up_dim // 2),
        dtype=torch.float16,
        device=hidden_states.device,
    )
    aiter.ck_moe_stage_1(
        hidden_states,
        w1,
        w2,
        sorted_token_ids,
        sorted_expert_ids,
        tokens_positions_per_expert,
        num_valid_ids,
        out,
        topk,
        use_int8_w8a8_block,
        use_fp8_w8a8_block,
        w1_scale,
        a1_scale,
        block_shape_n,
        block_shape_k,
        block_size,
        sorted_weights,
        act_op,
    )
    return out.to(dtype)
def ck_moe_stage_2(
    hidden_states,
    w1,  # [E, inter_dim*2, model_dim]
    w2,  # [E, model_dim, inter_dim]
    sorted_token_ids,  # [max_num_tokens_padded]
    sorted_expert_ids,  # [max_num_m_blocks]
    tokens_positions_per_expert,  # [num_experts*2]
    num_valid_ids,  # [1]
    use_int8_w8a8_block: bool,
    use_fp8_w8a8_block: bool,
    w2_scale,
    a2_scale,
    dtype,
    topk,
    block_shape_n=0,
    block_shape_k=0,
    block_size=16,
    sorted_weights=None,  # [max_num_tokens_padded]
    moe_buf=None,  # [token_num, model_dim]
):
    """Run the second stage of the two-stage CK MoE and return its output.

    ``hidden_states`` is flattened to 2-D over its last dimension before use.
    When ``moe_buf`` is None a zero-filled float32 buffer of shape
    ``(rows // topk, w2.shape[1])`` is allocated; otherwise ``moe_buf`` is
    passed to the kernel as given.

    Returns:
        The output buffer converted to ``dtype``.
    """
    # Tensor.reshape returns a new tensor, so the result must be rebound; the
    # original call discarded it and left hidden_states unchanged.
    hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
    if moe_buf is None:
        # Per the original author's notes: the buffer must start at zero
        # because the kernel accumulates with atomic adds, and float32 is
        # used because fp16 atomic add is not supported on the GPU.
        out = torch.zeros(
            (hidden_states.shape[0] // topk, w2.shape[1]),  # [token_num, model_dim]
            dtype=dtypes.fp32,
            device=hidden_states.device,
        )
    else:
        # NOTE(review): a caller-provided buffer is not cleared here.
        out = moe_buf
    aiter.ck_moe_stage_2(
        hidden_states,
        w1,
        w2,
        sorted_token_ids,
        sorted_expert_ids,
        tokens_positions_per_expert,
        num_valid_ids,
        out,
        topk,
        use_int8_w8a8_block,
        use_fp8_w8a8_block,
        w2_scale,
        a2_scale,
        block_shape_n,
        block_shape_k,
        block_size,
        sorted_weights,
    )
    return out.to(dtype)
def fused_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_int8_w8a16: Optional[bool] = False,
    use_int4_w4a16: Optional[bool] = False,
    use_int8_w8a8_block: Optional[bool] = False,
    use_int4_w4a8_block: Optional[bool] = False,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape_n: Optional[int] = 0,
    block_shape_k: Optional[int] = 0,
    block_m: Optional[int] = 32,
    solution_id: Optional[int] = 0,
    expert_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Shape-only stand-in passed as ``gen_fake`` to ``torch_compile_guard``.

    Returns an uninitialized float32 tensor of shape
    ``(topk_ids.shape[0], w2.size(1))`` on the device of ``topk_ids``. All
    other arguments are ignored.

    NOTE(review): this signature has no ``odtype`` parameter although the
    guarded implementations take one in sixth position -- confirm how
    ``torch_compile_guard`` invokes ``gen_fake``.
    """
    # Two-way unpacking also enforces that topk_ids is 2-D.
    num_tokens, _ = topk_ids.shape
    # FIXME: W2.size must be same as hidden_dim
    return torch.empty(
        (num_tokens, w2.size(1)),
        dtype=torch.float32,
        device=topk_ids.device,
    )
@torch_compile_guard(gen_fake=fused_moe_fake)
def ck_fused_experts_2stage_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    odtype: torch.dtype,  # compute or output type for i8 & f8
    inplace: bool = False,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_int4_w4a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    solution_id: Optional[int] = None) -> torch.Tensor:
    """Two-stage CK fused MoE for block-scaled int8 / fp8 W8A8 weights.

    Flow: ``moe_sorting`` -> optional block quantization of the input ->
    ``ck_moe_stage_1`` -> block quantization of the stage-1 output ->
    ``ck_moe_stage_2`` (which receives the sorted top-k weights).

    Args:
        hidden_states: 2-D activations. fp16/bf16 input is quantized here;
            any other dtype is used as-is together with ``a1_scale``.
        w1, w2: 3-D expert weights.
        odtype: dtype both stage outputs are converted to.
        activation: ``"silu"`` selects ``ActivationType.Silu``; anything else
            selects ``ActivationType.Gelu``.
        use_int8_w8a8 / use_fp8_w8a8: Select the quantized path; int8 wins
            when both are set.
        per_channel_quant: Not supported; a message is printed and None is
            returned.
        block_shape: ``[n, k]`` quantization block; ``None`` means ``(0, 0)``.

    Returns:
        The stage-2 output, or None when neither W8A8 flag is set or
        per-channel quantization is requested. Parameters not mentioned above
        are accepted for interface compatibility and unused.
    """
    _num_tokens, _ = hidden_states.shape  # also enforces a 2-D input
    E, _, _ = w1.shape
    _, model_dim, _ = w2.shape
    top_k_num = topk_ids.shape[1]
    # The original `a, b = x[0], x[1] if x is not None else (0, 0)` applied the
    # conditional to the second element only and raised TypeError for None.
    if block_shape is not None:
        quant_block_n, quant_block_k = block_shape[0], block_shape[1]
    else:
        quant_block_n, quant_block_k = 0, 0

    (
        sorted_ids,
        sorted_weights,
        sorted_expert_ids,
        num_valid_ids,
        tokens_positions_per_expert,
        moe_buf,
    ) = moe_sorting(topk_ids, topk_weights, E, model_dim, torch.float32, BLOCK_SIZE_M)

    # The int8 and fp8 paths differ only in the quantized dtype and in the two
    # flags forwarded to the stage kernels.
    if use_int8_w8a8:
        quant_dtype = torch.int8
        int8_block, fp8_block = True, False
        per_channel_msg = ">>> ck fused moe int8 w8a8 per channel not supported yet."
    elif use_fp8_w8a8:
        quant_dtype = torch.float8_e4m3fn
        int8_block, fp8_block = False, True
        per_channel_msg = ">>> ck fused moe fp8 w8a8 per channel not supported yet."
    else:
        return None
    if per_channel_quant:
        print(per_channel_msg)
        return None

    block_quant = per_block_quant_wrapper((1, quant_block_k))(per_token_quant_hip)

    # Quantize the input only when it arrives in half precision.
    if hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16:
        input_q, input_scale = block_quant(hidden_states, quant_dtype=quant_dtype)
    else:
        input_q, input_scale = hidden_states, a1_scale

    out_st1 = ck_moe_stage_1(
        input_q,  # for now the input is quantized outside the kernel
        w1,
        w2,
        sorted_ids,
        sorted_expert_ids,
        tokens_positions_per_expert,
        num_valid_ids,
        int8_block,
        fp8_block,
        w1_scale,
        input_scale,
        odtype,  # fp16/bf16 compute
        top_k_num,
        block_shape_n=quant_block_n,
        block_shape_k=quant_block_k,
        block_size=BLOCK_SIZE_M,
        Activation=ActivationType.Silu if activation == "silu" else ActivationType.Gelu,
        sorted_weights=None,  # stage 1 does not apply the top-k weights
    )

    # Re-quantize the stage-1 output before it enters stage 2.
    out_st1 = out_st1.reshape(-1, out_st1.shape[-1])
    bridge_q, bridge_scale = block_quant(out_st1, quant_dtype=quant_dtype)

    return ck_moe_stage_2(
        bridge_q,
        w1,
        w2,
        sorted_ids,
        sorted_expert_ids,
        tokens_positions_per_expert,
        num_valid_ids,
        int8_block,
        fp8_block,
        w2_scale,
        bridge_scale,
        odtype,  # fp16/bf16 compute
        top_k_num,
        block_shape_n=quant_block_n,
        block_shape_k=quant_block_k,
        block_size=BLOCK_SIZE_M,
        sorted_weights=sorted_weights,  # stage 2 applies the top-k weights
        moe_buf=moe_buf,
    )
@torch_compile_guard(gen_fake=fused_moe_fake)
def ck_fused_experts_1stage_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    odtype: torch.dtype,  # compute or output type for i8 & f8
    use_int8_w8a16: Optional[bool] = False,
    use_int4_w4a16: Optional[bool] = False,
    use_int8_w8a8_block: Optional[bool] = False,
    use_int4_w4a8_block: Optional[bool] = False,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape_n: Optional[int] = 0,
    block_shape_k: Optional[int] = 0,
    block_m: Optional[int] = 32,
    use_shuffle: Optional[bool] = False,
    solution_id: Optional[int] = 0,
    expert_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Single-kernel CK fused MoE.

    Forwards every argument except ``odtype`` and ``use_shuffle`` to
    ``ck_shuffle_moe`` when ``use_shuffle`` is set, otherwise to ``ck_moe``,
    and returns the kernel result converted to ``odtype``.
    """
    # Both kernels take the same positional argument list.
    shuffled = use_shuffle and use_shuffle == True
    kernel = ck_shuffle_moe if shuffled else ck_moe
    out = kernel(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        use_int8_w8a16,
        use_int4_w4a16,
        use_int8_w8a8_block,
        use_int4_w4a8_block,
        w1_zp,
        w2_zp,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
        block_shape_n,
        block_shape_k,
        block_m,
        solution_id,
        expert_mask,
    )
    return out.to(odtype)
def bits30_31(solution_id: int) -> int:
    """Return bits 31..30 of ``solution_id`` as an integer in the range 0..3.

    Negative values are read as 32-bit two's complement.
    """
    # An arithmetic right shift keeps the two's-complement bit pattern, so
    # masking after the shift isolates the same two bits.
    return (solution_id >> 30) & 0b11
# The outside interface
def run_fused_experts_ck_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    odtype: torch.dtype,  # compute or output type for i8 & f8
    inplace: bool = False,
    activation: str = "silu",
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    use_int4_w4a8: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    block_m: int = BLOCK_SIZE_M,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    use_shuffle: Optional[bool] = False,
    routed_scaling_factor: Optional[float] = 1.0,
    solution_id: Optional[int] = None) -> torch.Tensor:
    """Entry point: choose a CK MoE solution and dispatch to the matching impl.

    When ``solution_id`` is None it is derived as follows:
      * ``use_shuffle`` set -> 0 (only the one-stage path supports shuffle);
      * int8 W8A8 with ``block_shape[1] == 64`` -> ``1 << 30`` (only the
        two-stage path supports that block size);
      * otherwise -> ``get_moe_ck_solution_id`` lookup for the current GPU
        arch and problem shape.

    Bits 31..30 of the solution id select the implementation: value 1 runs
    ``ck_fused_experts_2stage_impl``, any other value runs
    ``ck_fused_experts_1stage_impl``.

    ``block_shape=None`` is treated as block sizes ``(0, 0)``.
    ``routed_scaling_factor`` is accepted but not used here, and
    ``activation``/``inplace``/``global_num_experts`` are forwarded only to
    the two-stage implementation.
    """
    if block_shape is not None:
        block_n, block_k = block_shape[0], block_shape[1]
    else:
        block_n, block_k = 0, 0

    if solution_id is None:
        if use_shuffle and use_shuffle == True:  # only one stage supports shuffle for now.
            solution_id = 0
        else:
            arch = get_gfx()
            quantType = MoeQuantType.INT8_W8A8 if use_int8_w8a8 else MoeQuantType.NO_QUANT
            E, model_dim, inter_dim = w2.shape
            topk = topk_ids.shape[1]
            # block_k is 0 when block_shape is None, so this no longer indexes
            # into None on the int8 path.
            if quantType == MoeQuantType.INT8_W8A8 and block_k == 64:
                solution_id = 1 << 30  # only two stage supports block_shape_k = 64
            else:
                solution_id = get_moe_ck_solution_id(
                    arch,
                    quantType,
                    hidden_states.shape[0],
                    inter_dim,
                    model_dim,
                    E,
                    topk,
                    block_n,
                    block_k,
                )

    solutionType = bits30_31(solution_id)
    # two stage
    if solutionType == 1:
        return ck_fused_experts_2stage_impl(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            odtype,
            inplace,
            activation,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            use_int4_w4a8,
            per_channel_quant,
            global_num_experts,
            expert_map,
            w1_scale,
            w2_scale,
            w1_zp,
            w2_zp,
            a1_scale,
            a2_scale,
            block_shape,
            solution_id)
    # one stage
    return ck_fused_experts_1stage_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        odtype,
        use_int8_w8a16,
        use_int4_w4a16,
        use_int8_w8a8,
        use_int4_w4a8,
        w1_zp,
        w2_zp,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
        block_n,
        block_k,
        block_m,
        use_shuffle,
        solution_id,
        expert_map)
\ No newline at end of file
import torch
from aiter import dtypes
# packed_4_bits (pack) = [0, 2, 4, 6, 1, 3, 5, 7]
# (unpack) = [0, 4, 1, 5, 2, 6, 3, 7]
# This code is adapted from https://github.com/ROCm/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py
# zeros are ignored since we use symmetric quantization
# qweight is both quantized and bit-packed along the same row. All the bits in the same row have the same scaling factor.
# 8 INT4s are packed into one INT32. INT4 instead of UINT4 is used.
################################################################################
# Custom Triton Kernel & Wrapper
################################################################################
def convert_int8_to_uint32_int4(tensor: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight int8 values into one uint32 word of 4-bit fields.

    The last dimension is split into groups of eight. Only the low four bits
    of every element are kept; element ``i`` of a group is stored in bits
    ``4*i .. 4*i+3`` of the output word.

    Args:
        tensor: int8 tensor of any rank >= 1 whose last dimension is a
            multiple of 8 (the original implementation accepted rank 3 only).

    Returns:
        Tensor viewed as ``torch.uint32`` with shape
        ``(*tensor.shape[:-1], tensor.shape[-1] // 8)``.

    Raises:
        AssertionError: if ``tensor`` is not int8.
        ValueError: if the last dimension is not a multiple of 8.
    """
    assert tensor.dtype == dtypes.i8, "input should be int8"
    if tensor.shape[-1] % 8 != 0:
        raise ValueError("k % 8 should be zero")

    grouped = tensor.reshape(*tensor.shape[:-1], tensor.shape[-1] // 8, 8)
    # Low nibble of every element, widened so the shifts below have room.
    nibbles = (grouped & 0x0F).to(dtypes.i32)

    # `...` indexing keeps this independent of the number of leading dims.
    merged = nibbles[..., 0]
    for position in range(1, 8):
        merged = merged | (nibbles[..., position] << (4 * position))
    return merged.view(dtype=torch.uint32)
def rearrange_4bit_elements(tensor):
    """
    Reorder the eight 4-bit fields inside every 32-bit word.

    With e0 denoting the most significant nibble (bits 28-31) and e7 the least
    significant one (bits 0-3):
    [e0, e1, e2, e3, e4, e5, e6, e7] -> [e0, e2, e4, e6, e1, e3, e5, e7]

    The input is reinterpreted as int32 and the result is returned viewed as
    torch.uint32.
    """
    words = tensor.view(dtype=dtypes.i32)

    # (mask, shift) pairs for the nibbles that move towards the high bits.
    moved_up = (
        (0x00F00000, 4),   # e2 -> bits 24-27
        (0x0000F000, 8),   # e4 -> bits 20-23
        (0x000000F0, 12),  # e6 -> bits 16-19
    )
    # (mask, shift) pairs for the nibbles that move towards the low bits.
    moved_down = (
        (0x0F000000, 12),  # e1 -> bits 12-15
        (0x000F0000, 8),   # e3 -> bits 8-11
        (0x00000F00, 4),   # e5 -> bits 4-7
    )

    # e0 and e7 keep their positions.
    result = (words & 0xF0000000) | (words & 0x0000000F)
    for mask, amount in moved_up:
        result = result | ((words & mask) << amount)
    for mask, amount in moved_down:
        result = result | ((words & mask) >> amount)
    return result.view(dtype=torch.uint32)
# SPDX-License-Identifier: MIT
\ No newline at end of file
# SPDX-License-Identifier: MIT
import functools
import importlib
import json
import logging
import multiprocessing
import os
import re
import shutil
import sys
import time
import traceback
import types
import typing
import copy
from typing import Any, Callable, List, Optional
from packaging.version import Version, parse
this_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, f"{this_dir}/utils/")
from chip_info import get_gfx
from cpp_extension import _jit_compile, executable_path, get_hip_version
from file_baton import FileBaton
from torch_guard import torch_compile_guard # noqa: E402
AITER_REBUILD = int(os.environ.get("AITER_REBUILD", "0"))
aiter_lib = None
def mp_lock(
    lockPath: str,
    MainFunc: Callable,
    FinalFunc: Optional[Callable] = None,
    WaitFunc: Optional[Callable] = None,
):
    """
    Using FileBaton for multiprocessing.

    The process that acquires the baton at ``lockPath`` runs ``MainFunc``,
    then ``FinalFunc`` (if given), and releases the baton. Every other process
    waits for the baton and then runs ``WaitFunc`` (if given).

    Returns:
        The result of ``MainFunc`` in the acquiring process; in a waiting
        process the result of ``WaitFunc``, or None when no ``WaitFunc`` was
        given. (Previously the ``WaitFunc`` result was always overwritten
        with None.)
    """
    baton = FileBaton(lockPath)
    if baton.try_acquire():
        try:
            return MainFunc()
        finally:
            # Release the baton even when FinalFunc raises, so waiting
            # processes are not left blocked on a stale lock.
            try:
                if FinalFunc is not None:
                    FinalFunc()
            finally:
                baton.release()

    baton.wait()
    if WaitFunc is not None:
        return WaitFunc()
    return None
# Logger shared by this module.
logger = logging.getLogger("aiter")
# Interpreter running this process.
PY = sys.executable
this_dir = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above the directory containing this file.
AITER_ROOT_DIR = os.path.abspath(f"{this_dir}/../../")
# Integer switches read from the environment; both default to 0.
AITER_LOG_MORE = int(os.getenv("AITER_LOG_MORE", 0))
AITER_LOG_TUNED_CONFIG = int(os.getenv("AITER_LOG_TUNED_CONFIG", 0))
# config_env start here
def update_config_files(file_path: str, merge_name: str):
    """Merge several tuned-config CSV files into one file and return its path.

    Args:
        file_path: One path, or several joined with ``os.pathsep``
            (example: AITER_CONFIG_GEMM_A4W4="/path1:/path2").
        merge_name: Base name of the merged file; also used to derive the
            name of the matching "untuned" config.

    Returns:
        ``file_path`` unchanged when it holds at most one path, otherwise
        ``/tmp/<merge_name>.csv`` containing the merged rows.

    Raises:
        AssertionError: if an additional file's columns differ from those of
            the first file.

    The first path is always read; later paths that do not exist are skipped
    with a log message. If ``<AITER_ROOT_DIR>/aiter/configs/<untuned>.csv``
    exists, its columns are used as keys: rows are sorted by the ``us`` column
    and only the first row per key is kept.
    """
    path_list = file_path.split(os.pathsep) if file_path else []
    if len(path_list) <= 1:
        return file_path

    import pandas as pd

    # merge config files
    df_list = [pd.read_csv(path_list[0])]
    reference_columns = df_list[0].columns.tolist()
    for i, path in enumerate(path_list[1:]):
        if os.path.exists(path):
            df = pd.read_csv(path)
            # check columns
            assert (
                df.columns.tolist() == reference_columns
            ), f"Column mismatch between {path_list[0]} and {path}, {reference_columns}, {df.columns.tolist()}"
            df_list.append(df)
        else:
            logger.info(f"path {i+1}: {path} (not exist)")
    merge_df = pd.concat(df_list, ignore_index=True)

    # get keys from untuned file to drop_duplicates
    # The original used re.sub with the replacement r"\1untuned" although the
    # pattern had no capturing group, which raises re.error for any name
    # ending in "tuned". Swapping the suffix directly gives the intended name.
    if merge_name.endswith("tuned"):
        untuned_name = merge_name[: -len("tuned")] + "untuned"
    else:
        untuned_name = merge_name.replace("tuned", "untuned")
    untuned_path = f"{AITER_ROOT_DIR}/aiter/configs/{untuned_name}.csv"
    if os.path.exists(untuned_path):
        untunedf = pd.read_csv(untuned_path)
        keys = untunedf.columns
        merge_df = (
            merge_df.sort_values("us")
            .drop_duplicates(subset=keys, keep="first")
            .reset_index(drop=True)
        )
    else:
        # NOTE(review): despite the message, no deduplication is performed on
        # this path -- confirm whether that is intended.
        logger.warning(
            f"Untuned config file not found: {untuned_path}. Using all columns for deduplication."
        )
    new_file_path = f"/tmp/{merge_name}.csv"
    merge_df.to_csv(new_file_path, index=False)
    return new_file_path
def get_config_file(env_name, default_file, tuned_file_name):
    """Resolve the tuned-config CSV to use for one operator.

    Args:
        env_name: Environment variable that may hold one or more paths
            separated by ``os.pathsep``.
        default_file: Path used when the variable is unset or empty.
        tuned_file_name: Base name used to find extra tuned files under
            ``<AITER_ROOT_DIR>/aiter/configs/model_configs/`` and to name the
            merged output.

    Returns:
        The path chosen by ``update_config_files`` when the variable is set;
        otherwise ``default_file``, or a merged file when matching tuned files
        (names not containing "untuned") exist in ``model_configs/``.
    """
    from pathlib import Path

    config_env_file = os.getenv(env_name)
    if config_env_file:
        return update_config_files(config_env_file, tuned_file_name)

    model_config_dir = Path(f"{AITER_ROOT_DIR}/aiter/configs/model_configs/")
    op_tuned_file_list = [
        p
        for p in model_config_dir.glob(f"*{tuned_file_name}*")
        if (p.is_file() and "untuned" not in str(p))
    ]
    if not op_tuned_file_list:
        return default_file

    # Join with os.pathsep because update_config_files splits on os.pathsep;
    # a hard-coded ":" only matched where the two happen to coincide.
    tuned_files = os.pathsep.join([default_file, *(str(p) for p in op_tuned_file_list)])
    logger.info(
        f"merge tuned file under model_configs/ and configs/ {tuned_files}"
    )
    return update_config_files(tuned_files, tuned_file_name)
# Raw per-operator config paths: the value of the environment variable of the
# same name if set, otherwise the CSV under <AITER_ROOT_DIR>/aiter/configs/.
AITER_CONFIG_GEMM_A4W4 = os.getenv(
    "AITER_CONFIG_GEMM_A4W4",
    f"{AITER_ROOT_DIR}/aiter/configs/a4w4_blockscale_tuned_gemm.csv",
)
AITER_CONFIG_GEMM_A8W8 = os.getenv(
    "AITER_CONFIG_GEMM_A8W8",
    f"{AITER_ROOT_DIR}/aiter/configs/a8w8_tuned_gemm.csv",
)
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE = os.getenv(
    "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE",
    f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv",
)
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE = os.getenv(
    "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE",
    f"{AITER_ROOT_DIR}/aiter/configs/a8w8_blockscale_tuned_gemm.csv",
)
AITER_CONFIG_FMOE = os.getenv(
    "AITER_CONFIG_FMOE",
    f"{AITER_ROOT_DIR}/aiter/configs/tuned_fmoe.csv",
)
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE = os.getenv(
    "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE",
    f"{AITER_ROOT_DIR}/aiter/configs/a8w8_blockscale_bpreshuffle_tuned_gemm.csv",
)
AITER_CONFIG_A8W8_BATCHED_GEMM = os.getenv(
    "AITER_CONFIG_A8W8_BATCHED_GEMM",
    f"{AITER_ROOT_DIR}/aiter/configs/a8w8_tuned_batched_gemm.csv",
)
AITER_CONFIG_BF16_BATCHED_GEMM = os.getenv(
    "AITER_CONFIG_BF16_BATCHED_GEMM",
    f"{AITER_ROOT_DIR}/aiter/configs/bf16_tuned_batched_gemm.csv",
)
AITER_CONFIG_GEMM_BF16 = os.getenv(
    "AITER_CONFIG_GEMM_BF16",
    f"{AITER_ROOT_DIR}/aiter/configs/tuned_gemm.csv",
)
# Resolved config files. get_config_file() reads the environment variable
# again and may merge several CSVs into a single file, so these calls can
# read and write files at import time.
AITER_CONFIG_GEMM_A4W4_FILE = get_config_file(
    "AITER_CONFIG_GEMM_A4W4", AITER_CONFIG_GEMM_A4W4, "a4w4_blockscale_tuned_gemm"
)
AITER_CONFIG_GEMM_A8W8_FILE = get_config_file(
    "AITER_CONFIG_GEMM_A8W8", AITER_CONFIG_GEMM_A8W8, "a8w8_tuned_gemm"
)
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE = get_config_file(
    "AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE",
    AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE,
    "a8w8_bpreshuffle_tuned_gemm",
)
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE = get_config_file(
    "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE",
    AITER_CONFIG_GEMM_A8W8_BLOCKSCALE,
    "a8w8_blockscale_tuned_gemm",
)
AITER_CONFIG_FMOE_FILE = get_config_file(
    "AITER_CONFIG_FMOE", AITER_CONFIG_FMOE, "tuned_fmoe"
)
AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE = get_config_file(
    "AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE",
    AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE,
    "a8w8_blockscale_bpreshuffle_tuned_gemm",
)
AITER_CONFIG_A8W8_BATCHED_GEMM_FILE = get_config_file(
    "AITER_CONFIG_A8W8_BATCHED_GEMM",
    AITER_CONFIG_A8W8_BATCHED_GEMM,
    "a8w8_tuned_batched_gemm",
)
AITER_CONFIG_BF16_BATCHED_GEMM_FILE = get_config_file(
    "AITER_CONFIG_BF16_BATCHED_GEMM",
    AITER_CONFIG_BF16_BATCHED_GEMM,
    "bf16_tuned_batched_gemm",
)
# NOTE(review): the default file above is tuned_gemm.csv while the merge name
# passed here is "bf16_tuned_gemm" -- confirm the mismatch is intended.
AITER_CONFIG_GEMM_BF16_FILE = get_config_file(
    "AITER_CONFIG_GEMM_BF16", AITER_CONFIG_GEMM_BF16, "bf16_tuned_gemm"
)
# config_env end here
# Decide where the aiter metadata tree (csrc, hsa, 3rdparty, ...) lives,
# depending on whether the "aiter" package can be found and how it was
# installed.
find_aiter = importlib.util.find_spec("aiter")
if find_aiter is not None:
    if find_aiter.submodule_search_locations:
        package_path = find_aiter.submodule_search_locations[0]
    elif find_aiter.origin:
        package_path = find_aiter.origin
    # NOTE(review): package_path is unbound if the spec has neither attribute.
    package_path = os.path.dirname(package_path)
    package_parent_path = os.path.dirname(package_path)
    try:
        with open(f"{this_dir}/../install_mode", "r") as f:
            # develop mode
            isDevelopMode = f.read().strip() == "develop"
    except FileNotFoundError:
        # pip install -e
        isDevelopMode = True
    if isDevelopMode:
        AITER_META_DIR = AITER_ROOT_DIR
    # install mode
    else:
        AITER_META_DIR = os.path.abspath(f"{AITER_ROOT_DIR}/aiter_meta/")
else:
    AITER_META_DIR = AITER_ROOT_DIR
    logger.warning("aiter is not installed.")
# Make modules under the metadata directory importable.
sys.path.insert(0, AITER_META_DIR)
AITER_CSRC_DIR = f"{AITER_META_DIR}/csrc"
AITER_GRADLIB_DIR = f"{AITER_META_DIR}/gradlib"
gfx = get_gfx()
# Per-architecture assembly directory, also exported through the environment.
AITER_ASM_DIR = f"{AITER_META_DIR}/hsa/{gfx}/"
os.environ["AITER_ASM_DIR"] = AITER_ASM_DIR
# Third-party source locations; CK_DIR / MOE_C_DIR environment variables
# override the bundled copies.
CK_3RDPARTY_DIR = os.environ.get(
    "CK_DIR", f"{AITER_META_DIR}/3rdparty/composable_kernel"
)
CK_DIR = CK_3RDPARTY_DIR
MOE_C_3RDPARTY_DIR = os.environ.get(
    "MOE_C_DIR", f"{AITER_META_DIR}/3rdparty/moe_c"
)
MOE_C_DIR = MOE_C_3RDPARTY_DIR
os.environ["AITER_META_DIR"] = AITER_META_DIR
@functools.lru_cache(maxsize=1)
def get_asm_dir():
    """Return the module-level ``AITER_ASM_DIR`` (result is cached)."""
    return AITER_ASM_DIR
@functools.lru_cache(maxsize=1)
def get_user_jit_dir() -> str:
    """Return the directory used for JIT artifacts (result is cached).

    Resolution order:
      1. ``AITER_JIT_DIR`` from the environment: the directory is created if
         needed and prepended to ``sys.path``.
      2. The directory of this file, when it is writable.
      3. ``~/.aiter/<basename of this directory>``, created as a copy of this
         directory on first use.
    """
    if "AITER_JIT_DIR" in os.environ:
        jit_dir = os.getenv("AITER_JIT_DIR", "")
        os.makedirs(jit_dir, exist_ok=True)
        sys.path.insert(0, jit_dir)
        return jit_dir

    if os.access(this_dir, os.W_OK):
        return this_dir

    # Package directory is read-only: fall back to a per-user copy.
    user_home = os.path.expanduser('~')
    home_jit_dir = f"{user_home}/.aiter/{os.path.basename(this_dir)}"
    if not os.path.exists(home_jit_dir):
        shutil.copytree(this_dir, home_jit_dir)
    return home_jit_dir
# Build directory inside the user JIT directory.
bd_dir = f"{get_user_jit_dir()}/build"
# copy ck to build, thus hipify under bd_dir
# Only the process named "MainProcess" creates the directory.
if multiprocessing.current_process().name == "MainProcess":
    os.makedirs(bd_dir, exist_ok=True)
    # if os.path.exists(f"{bd_dir}/ck/library"):
    #     shutil.rmtree(f"{bd_dir}/ck/library")
# CK_DIR = f"{bd_dir}/ck"
def validate_and_update_archs():
    """Return the GPU architectures requested through ``GPU_ARCHS``.

    ``GPU_ARCHS`` is a ``;``-separated list (default ``"native"``); each entry
    is stripped of surrounding whitespace.

    Raises:
        AssertionError: if any entry is not in the list of allowed
            architectures.
    """
    requested = os.getenv("GPU_ARCHS", "native")
    archs = [item.strip() for item in requested.split(";")]

    # List of allowed architectures
    allowed_archs = [
        "native",
        "gfx90a",
        "gfx940",
        "gfx941",
        "gfx942",
        "gfx1100",
        "gfx950",
        "gfx928",
        "gfx936",
        "gfx938",
        "gfx946",
    ]

    # Validate if each element in archs is in allowed_archs
    has_unknown = any(arch not in allowed_archs for arch in archs)
    assert not has_unknown, f"One of GPU archs of {archs} is invalid or not supported"
    return archs
@functools.lru_cache()
def hip_flag_checker(flag_hip: str) -> bool:
    """Return True when hipcc accepts ``flag_hip`` (cached per flag string).

    The flag is tried by preprocessing ``/dev/null`` as HIP source through
    ``os.system``; a non-zero status logs a warning and yields False.
    """
    hipcc = executable_path("hipcc")
    status = os.system(f'"{hipcc}" {flag_hip} -x hip -E -P /dev/null -o /dev/null')
    supported = status == 0
    if not supported:
        logger.warning(f"{flag_hip} is not supported by hipcc.")
    return supported
def _path_under_prefix(path: str, prefix: str) -> bool:
    """Return True when ``path`` resolves to ``prefix`` or a location below it.

    Both arguments are resolved with ``os.path.realpath`` first. An empty
    ``path`` yields False, as does a pair of paths for which
    ``os.path.commonpath`` raises ``ValueError``.
    """
    if not path:
        return False
    resolved_path = os.path.realpath(path)
    resolved_prefix = os.path.realpath(prefix)
    try:
        shared = os.path.commonpath([resolved_path, resolved_prefix])
    except ValueError:
        return False
    else:
        return shared == resolved_prefix
@functools.lru_cache(maxsize=1)
def detect_dtk_env() -> bool:
    """Return True when the toolchain appears to live under ``/opt/dtk`` (cached).

    Detection steps:
      1. Look for ``aicc`` (DTK's renamed hipcc) on PATH, then at
         ``/opt/dtk/bin/aicc``. If found, it is exported as ``HIPCC`` in the
         environment so later ``executable_path("hipcc")`` lookups use it.
      2. Otherwise resolve hipcc normally, falling back to ``shutil.which``
         when ``executable_path`` raises.
      3. DTK is considered enabled when the selected compiler or
         ``ROCM_PATH`` resolves to a location under ``/opt/dtk``.
    """
    # Try to locate 'aicc' first (DTK's renamed hipcc)
    aicc_path = shutil.which("aicc")
    if not aicc_path:
        fallback_aicc = "/opt/dtk/bin/aicc"
        if os.path.exists(fallback_aicc):
            aicc_path = os.path.realpath(fallback_aicc)

    if aicc_path:
        # Use aicc as the hipcc implementation for the rest of the process.
        hipcc = os.path.realpath(aicc_path)
        os.environ["HIPCC"] = hipcc
        hipcc_in_dtk = _path_under_prefix(hipcc, "/opt/dtk")
        logger.info(f"Found 'aicc' and using it as hipcc: {hipcc}")
    else:
        # Fallback to normal hipcc resolution (may raise/abort in executable_path)
        try:
            hipcc = executable_path("hipcc")
        except Exception:
            # Best-effort lookup when executable_path fails.
            hipcc = shutil.which("hipcc") or ""
        if hipcc:
            hipcc = os.path.realpath(hipcc)
        hipcc_in_dtk = _path_under_prefix(hipcc, "/opt/dtk") if hipcc else False

    # Also consider ROCM_PATH pointing under /opt/dtk
    rocm_path = os.getenv("ROCM_PATH", "")
    rocm_in_dtk = _path_under_prefix(rocm_path, "/opt/dtk")

    enabled = hipcc_in_dtk or rocm_in_dtk
    if not enabled:
        logger.info(
            f"Non-DTK environment (hipcc={hipcc}, ROCM_PATH={rocm_path}), DTK_ENV disabled"
        )
    else:
        logger.info(
            f"DTK environment detected (hipcc={hipcc}, ROCM_PATH={rocm_path}), enabling -DDTK_ENV"
        )
    return enabled
def check_and_set_ninja_worker():
    """Cap the ``MAX_JOBS`` environment variable for compilation.

    The cap is the smaller of 80% of the CPU count and the number of jobs that
    fit into the available memory at 0.5 GB per job, and never below 1.
    ``MAX_JOBS`` is set to the cap when it is unset, not an integer, or larger
    than the cap; a smaller valid value is left untouched.
    """
    import psutil

    # os.cpu_count() may return None when the count cannot be determined;
    # treat that as a single core instead of failing on `None * 0.8`.
    cpu_total = os.cpu_count() or 1
    max_num_jobs_cores = max(1, cpu_total * 0.8)

    # calculate the maximum allowed NUM_JOBS based on free memory
    free_memory_gb = psutil.virtual_memory().available / (1024**3)  # free memory in GB
    max_num_jobs_memory = int(free_memory_gb / 0.5)  # assuming 0.5 GB per job

    # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
    max_jobs = int(max(1, min(max_num_jobs_cores, max_num_jobs_memory)))

    max_jobs_env = os.environ.get("MAX_JOBS")
    try:
        requested = int(max_jobs_env) if max_jobs_env is not None else None
    except ValueError:
        # A non-integer value is treated like an unset one.
        requested = None
    if requested is None or requested > max_jobs:
        os.environ["MAX_JOBS"] = str(max_jobs)
def rename_cpp_to_cu(els, dst, hipify, recursive=False):
    """Collect ``.cpp``/``.cu`` sources, optionally staging them into ``dst``.

    Args:
        els: Iterable of file or directory paths. Paths that do not exist are
            logged and skipped.
        dst: Directory that receives copies when ``hipify`` is true.
        hipify: When true, every visited file is copied into ``dst`` and a
            trailing ``.cpp`` is renamed to ``.cu``; the returned paths point
            into ``dst``. When false nothing is copied and the returned paths
            point at the original locations.
        recursive: When true, sub-directories of a directory entry are
            descended into; otherwise they are ignored.

    Returns:
        List of paths of the ``.cpp``/``.cu`` files found. Other files (for
        example headers) are copied in hipify mode but never listed.
    """

    def do_rename_and_mv(name, src, dst, ret):
        is_source = name.endswith((".cpp", ".cu"))
        if hipify:
            new_name = name
            if name.endswith(".cpp"):
                # Swap only the extension; str.replace would also rewrite a
                # ".cpp" that happens to appear earlier in the file name.
                new_name = name[: -len(".cpp")] + ".cu"
            if is_source:
                ret.append(f"{dst}/{new_name}")
            shutil.copy(f"{src}/{name}", f"{dst}/{new_name}")
        elif is_source:
            ret.append(f"{src}/{name}")

    ret = []
    for el in els:
        if not os.path.exists(el):
            logger.warning(f"---> {el} not exists!!!!!!")
            continue
        if not os.path.isdir(el):
            do_rename_and_mv(os.path.basename(el), os.path.dirname(el), dst, ret)
            continue
        for entry in os.listdir(el):
            if os.path.isdir(f"{el}/{entry}"):
                if recursive:
                    ret += rename_cpp_to_cu([f"{el}/{entry}"], dst, hipify, recursive)
                continue
            do_rename_and_mv(entry, el, dst, ret)
    return ret
@torch_compile_guard()
def check_numa_custom_op() -> None:
    """Warn when kernel NUMA balancing is switched on.

    Reads ``/proc/sys/kernel/numa_balancing`` and logs a warning if its
    content is ``"1"``. A missing or unreadable file is treated as
    "not enabled", so nothing is logged in that case.
    """
    # Read the file directly rather than spawning `cat` through a shell:
    # no child process, and the handle is closed deterministically.
    try:
        with open("/proc/sys/kernel/numa_balancing") as numa_file:
            numa_balance_set = numa_file.read().strip()
    except OSError:
        numa_balance_set = ""
    if numa_balance_set == "1":
        logger.warning(
            "WARNING: NUMA balancing is enabled, which may cause errors. "
            "It is recommended to disable NUMA balancing by running \"sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'\" "
        )
@functools.lru_cache()
def check_numa():
    """Run the NUMA-balancing check; the cache makes repeat calls no-ops."""
    check_numa_custom_op()
# Registry of already-imported JIT modules, keyed by module name.
__mds = {}
@torch_compile_guard()
def get_module_custom_op(md_name: str) -> None:
    """Import the module ``md_name`` and remember it in ``__mds``.

    When the ``AITER_JIT_DIR`` environment variable is set the module is
    imported by its bare name; otherwise it is imported relative to this
    package. Modules that are already registered are left untouched.

    Args:
        md_name: Name of the compiled module to import.
    """
    global __mds
    if md_name in __mds:
        return
    import_name = (
        md_name if "AITER_JIT_DIR" in os.environ else f"{__package__}.{md_name}"
    )
    __mds[md_name] = importlib.import_module(import_name)
    if AITER_LOG_MORE:
        logger.info(f"import [{md_name}] under {__mds[md_name].__file__}")
    return
@functools.lru_cache(maxsize=1024)
def get_module(md_name):
    """Return the imported module registered under ``md_name``.

    Runs the NUMA check, imports the module if it is not registered yet
    (any import error propagates to the caller), then returns the entry
    from ``__mds``. Results are cached per ``md_name``.
    """
    check_numa()
    get_module_custom_op(md_name)
    return __mds[md_name]
# Modules that compile_ops will not force-rebuild (again) when AITER_REBUILD is
# set; compile_ops appends each module it sends to a rebuild.
# "module_aiter_enum" is listed up front, so it is never force-rebuilt.
rebuilded_list = ["module_aiter_enum"]
def rm_module(md_name):
    """Delete ``{md_name}.so`` from the user JIT directory if it exists.

    The removal is best-effort: a missing target is ignored and any other
    OS error is logged as a warning instead of being raised.

    Args:
        md_name: Module name without the ``.so`` suffix.
    """
    # Use filesystem calls instead of an `rm -rf` shell string, which
    # mis-parses paths containing spaces or shell metacharacters.
    target = f"{get_user_jit_dir()}/{md_name}.so"
    try:
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        else:
            os.remove(target)
    except FileNotFoundError:
        pass
    except OSError as err:
        logger.warning(f"failed to remove {target}: {err}")
def clear_build(md_name):
    """Delete the build tree ``{bd_dir}/{md_name}`` if it exists.

    The removal is best-effort: a missing target is ignored and any other
    OS error is logged as a warning instead of being raised.

    Args:
        md_name: Module name whose build directory should be removed.
    """
    # Use filesystem calls instead of an `rm -rf` shell string, which
    # mis-parses paths containing spaces or shell metacharacters.
    target = f"{bd_dir}/{md_name}"
    try:
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        else:
            os.remove(target)
    except FileNotFoundError:
        pass
    except OSError as err:
        logger.warning(f"failed to remove {target}: {err}")
def build_module(
    md_name,
    srcs,
    flags_extra_cc,
    flags_extra_hip,
    blob_gen_cmd,
    extra_include,
    extra_ldflags,
    verbose,
    is_python_module,
    is_standalone,
    torch_exclude,
    hipify=False,
):
    """JIT-compile one module and install the resulting binary.

    The actual work happens in ``MainFunc``, which is handed to ``mp_lock``
    together with a lock path specific to ``md_name``; ``FinalFunc`` logs
    the elapsed time.

    Args:
        md_name: Module name; also names the build directory and the target.
        srcs: Source files or directories passed to ``rename_cpp_to_cu``.
        flags_extra_cc: Extra host-compiler flags appended to the defaults.
        flags_extra_hip: Extra HIP-compiler flags appended to the defaults.
        blob_gen_cmd: Generator command (or list of commands) containing a
            ``{}`` placeholder for the blob output directory; falsy to skip.
        extra_include: Extra include directories.
        extra_ldflags: Linker flags forwarded to ``_jit_compile``.
        verbose: Forwarded to ``_jit_compile`` (also on when
            ``AITER_LOG_MORE > 1``).
        is_python_module: Forwarded to ``_jit_compile``; together with
            ``is_standalone`` it selects where the output is copied.
        is_standalone: When true the target has no ``.so`` suffix.
        torch_exclude: When false, torch is imported to probe for
            ``float4_e2m1fn_x2`` support.
        hipify: When true, sources and headers are staged into the build
            directory via ``rename_cpp_to_cu``.

    Raises:
        SystemExit: When compilation or the final copy fails; the original
            exception is attached as the cause.
    """
    lock_path = f"{bd_dir}/lock_{md_name}"
    startTS = time.perf_counter()
    target_name = f"{md_name}.so" if not is_standalone else md_name

    def MainFunc():
        # AITER_REBUILD == 1 wipes both the installed .so and the build
        # tree; >= 2 removes only the installed .so.
        if AITER_REBUILD == 1:
            rm_module(md_name)
            clear_build(md_name)
        elif AITER_REBUILD >= 2:
            rm_module(md_name)
        op_dir = f"{bd_dir}/{md_name}"
        logger.info(f"start build [{md_name}] under {op_dir}")
        opbd_dir = f"{op_dir}/build"
        src_dir = f"{op_dir}/build/srcs"
        os.makedirs(src_dir, exist_ok=True)
        # Drop any previously installed target before rebuilding.
        if os.path.exists(f"{get_user_jit_dir()}/{target_name}"):
            os.remove(f"{get_user_jit_dir()}/{target_name}")
        sources = rename_cpp_to_cu(srcs, src_dir, hipify)
        flags_cc = ["-O3", "-std=c++20"]
        flags_hip = [
            # "-DLEGACY_HIPBLAS_DIRECT",
            "-DUSE_PROF_API=1",
            "-D__HIP_PLATFORM_HCC__=1",
            "-D__HIP_PLATFORM_AMD__=1",
            "-U__HIP_NO_HALF_CONVERSIONS__",
            "-U__HIP_NO_HALF_OPERATORS__",
            "-mllvm --amdgpu-kernarg-preload-count=16",
            # "-v --save-temps",
            "-Wno-unused-result",
            "-Wno-switch-bool",
            "-Wno-vla-cxx-extension",
            "-Wno-undefined-func-template",
            "-Wno-macro-redefined",
            "-Wno-missing-template-arg-list-after-template-kw",
            "-fgpu-flush-denormals-to-zero",
        ]
        # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
        # The last whitespace-separated token of the hipconfig output is
        # parsed as the version; '-' is mapped to '+' before parsing.
        hip_version = parse(get_hip_version().split()[-1].rstrip("-").replace("-", "+"))
        if hip_version > Version("5.5.00000"):
            flags_hip += ["-mllvm --lsr-drop-solution=1"]
        if hip_version > Version("5.7.23302"):
            flags_hip += ["-fno-offload-uniform-block"]
        if hip_version > Version("6.1.40090"):
            flags_hip += ["-mllvm -enable-post-misched=0"]
        if hip_version > Version("6.2.41132"):
            flags_hip += [
                "-mllvm -amdgpu-early-inline-all=true",
                "-mllvm -amdgpu-function-calls=false",
            ]
        if hip_version > Version("6.2.41133"):
            flags_hip += ["-mllvm -amdgpu-coerce-illegal-types=1"]
        # FP4x2 define: only on gfx946 and only while AITER_FP4x2 (default 1) is > 0.
        if get_gfx() == "gfx946" and int(os.getenv("AITER_FP4x2", "1")) > 0:
            flags_hip += ["-D__Float4_e2m1fn_x2"]
        if not torch_exclude:
            import torch
            if hasattr(torch, "float4_e2m1fn_x2"):
                flags_hip += ["-DTORCH_Float4_e2m1fn_x2"]
        # Enable DTK code path only when hipcc/ROCM_PATH indicates /opt/dtk
        if detect_dtk_env():
            flags_cc.append("-DDTK_ENV")
            flags_hip.append("-DDTK_ENV")
        flags_cc += flags_extra_cc
        flags_hip += flags_extra_hip
        archs = validate_and_update_archs()
        flags_hip += [f"--offload-arch={arch}" for arch in archs]
        if any(arch == "gfx938" for arch in archs) or get_gfx()=="gfx938":
            flags_hip.append("-DGPU_ENABLE_FP8") # device
            flags_cc.append("-DGPU_ENABLE_FP8") # host
        flags_hip = sorted(set(flags_hip))  # remove same flags
        # Keep only the flags that hip_flag_checker accepts.
        flags_hip = [el for el in flags_hip if hip_flag_checker(el)]
        check_and_set_ninja_worker()

        def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
            # Run one generator command with the blob directory substituted
            # for '{}', then add everything it produced to the source list.
            if blob_gen_cmd:
                blob_dir = f"{op_dir}/blob"
                os.makedirs(blob_dir, exist_ok=True)
                if AITER_LOG_MORE:
                    logger.info(f"exec_blob ---> {PY} {blob_gen_cmd.format(blob_dir)}")
                os.system(f"{PY} {blob_gen_cmd.format(blob_dir)}")
                sources += rename_cpp_to_cu([blob_dir], src_dir, hipify, recursive=True)
            return sources

        if isinstance(blob_gen_cmd, list):
            for s_blob_gen_cmd in blob_gen_cmd:
                sources = exec_blob(s_blob_gen_cmd, op_dir, src_dir, sources)
        else:
            sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources)
        extra_include_paths = [
            f"{CK_DIR}/include",
            f"{CK_DIR}/library/include",
        ]
        if not hipify:
            # Headers are used in place from the source tree.
            extra_include_paths += [
                f"{AITER_CSRC_DIR}/include",
                f"{op_dir}/blob",
            ] + extra_include
            if not is_standalone:
                extra_include_paths += [f"{AITER_CSRC_DIR}/include/torch"]
        else:
            # Headers are copied into the build tree and included from there.
            old_bd_include_dir = f"{op_dir}/build/include"
            extra_include_paths.append(old_bd_include_dir)
            os.makedirs(old_bd_include_dir, exist_ok=True)
            rename_cpp_to_cu(
                [f"{AITER_CSRC_DIR}/include"] + extra_include,
                old_bd_include_dir,
                hipify,
            )
            if not is_standalone:
                bd_include_dir = f"{op_dir}/build/include/torch"
                os.makedirs(bd_include_dir, exist_ok=True)
                rename_cpp_to_cu(
                    [f"{AITER_CSRC_DIR}/include/torch"],
                    bd_include_dir,
                    hipify,
                )
        try:
            _jit_compile(
                md_name,
                sorted(set(sources)),
                extra_cflags=flags_cc,
                extra_cuda_cflags=flags_hip,
                extra_ldflags=extra_ldflags,
                extra_include_paths=extra_include_paths,
                build_directory=opbd_dir,
                verbose=verbose or AITER_LOG_MORE > 1,
                with_cuda=True,
                is_python_module=is_python_module,
                is_standalone=is_standalone,
                torch_exclude=torch_exclude,
                hipify=hipify,
            )
            # Importable modules go to the user JIT directory; every other
            # target is copied to the op_tests/cpp/mha directory.
            if is_python_module and not is_standalone:
                shutil.copy(f"{opbd_dir}/{target_name}", f"{get_user_jit_dir()}")
            else:
                shutil.copy(
                    f"{opbd_dir}/{target_name}", f"{AITER_ROOT_DIR}/op_tests/cpp/mha"
                )
        except Exception as e:
            # Log the full traceback with "error:" highlighted in red, then
            # abort with SystemExit chained to the original exception.
            tag = f"\033[31mfailed jit build [{md_name}]\033[0m"
            logger.error(
                f"{tag}\u2193\u2193\u2193\u2193\u2193\u2193\u2193\u2193\u2193\u2193\n-->[History]: {{}}{tag}\u2191\u2191\u2191\u2191\u2191\u2191\u2191\u2191\u2191\u2191".format(
                    re.sub(
                        "error:",
                        "\033[31merror:\033[0m",
                        "-->".join(traceback.format_exception(*sys.exc_info())),
                        flags=re.I,
                    ),
                )
            )
            raise SystemExit(
                f"[aiter] build [{md_name}] under {opbd_dir} failed !!!!!!"
            ) from e

    def FinalFunc():
        # Report the time elapsed since build_module was entered.
        logger.info(
            f"\033[32mfinish build [{md_name}], cost {time.perf_counter()-startTS:.1f}s \033[0m"
        )

    mp_lock(lockPath=lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc)
def get_args_of_build(ops_name: str, exclude=None):
    """Load build arguments for one op, or for all ops, from optCompilerConfig.json.

    String values in the JSON file are Python expressions; they are
    evaluated here and merged over a set of default build arguments.

    Args:
        ops_name: Key of the op in the JSON file, or ``"all"``.
        exclude: Op names to skip when ``ops_name == "all"``. ``None``
            (the default) excludes nothing.

    Returns:
        * ``ops_name == "all"``: a tuple ``(all_ops_list, d_all_ops)`` with
          one dict per op and a dict aggregating flags, includes, ldflags
          and blob commands across them. Ops whose name ends with "tune"
          are skipped.
        * op found: the resolved build-argument dict for that op.
        * op not found: the default build-argument dict (a warning is
          logged).
        * JSON root is not a dict: ``None`` (a warning is logged).
    """
    # A None default avoids sharing one list object between calls.
    excluded = () if exclude is None else exclude
    d_opt_build_args = {
        "srcs": [],
        "md_name": "",
        "flags_extra_cc": [],
        "flags_extra_hip": [],
        "extra_ldflags": None,
        "extra_include": [],
        "verbose": False,
        "hipify": False,
        "is_python_module": True,
        "is_standalone": False,
        "torch_exclude": False,
        "hip_clang_path": None,
        "blob_gen_cmd": "",
        "skip_if": False,
    }

    def convert(d_ops: dict):
        # NOTE(security): every string in the config is passed to eval(), so
        # optCompilerConfig.json must be treated as trusted code.
        converted_ops = {}
        for k, val in d_ops.items():
            if isinstance(val, list):
                converted_list = list(val)
                for idx, el in enumerate(val):
                    if isinstance(el, str):
                        if "torch" in el:
                            # Bring torch into scope for expressions that use it.
                            import torch as torch
                        converted_list[idx] = eval(el)
                    else:
                        converted_list[idx] = el
                converted_ops[k] = converted_list
            elif isinstance(val, str):
                converted_ops[k] = eval(val)
            else:
                converted_ops[k] = val
        # undefined compile features will be replaced with default value
        resolved_build_args = copy.deepcopy(d_opt_build_args)
        resolved_build_args.update(converted_ops)
        return resolved_build_args

    with open(this_dir + "/optCompilerConfig.json", "r") as file:
        data = json.load(file)

    if not isinstance(data, dict):
        logger.warning(
            "ERROR: pls use dict_format to write 'optCompilerConfig.json'! "
        )
        return None

    # parse all ops, return list
    if ops_name == "all":
        all_ops_list = []
        d_all_ops = {
            "flags_extra_cc": [],
            "flags_extra_hip": [],
            "extra_include": [],
            "extra_ldflags": [],
            "blob_gen_cmd": [],
        }
        # traverse opts; a separate loop name keeps the ops_name argument intact
        for single_name, d_ops in data.items():
            # Cannot contain tune ops
            if single_name.endswith("tune"):
                continue
            # exclude
            if single_name in excluded:
                continue
            single_ops = convert(d_ops)
            d_single_ops = {
                "md_name": single_name,
                "srcs": single_ops["srcs"],
                "flags_extra_cc": single_ops["flags_extra_cc"],
                "flags_extra_hip": single_ops["flags_extra_hip"],
                "extra_include": single_ops["extra_include"],
                "extra_ldflags": single_ops["extra_ldflags"],
                "blob_gen_cmd": single_ops["blob_gen_cmd"],
                "verbose": single_ops["verbose"],
                "hipify": single_ops["hipify"],
                "skip_if": single_ops.get("skip_if", False),
            }
            for k in d_all_ops:
                if isinstance(single_ops[k], list):
                    d_all_ops[k] += single_ops[k]
                elif isinstance(single_ops[k], str) and single_ops[k] != "":
                    d_all_ops[k].append(single_ops[k])
            all_ops_list.append(d_single_ops)
        return all_ops_list, d_all_ops

    # no find opt_name in json.
    compile_ops_ = data.get(ops_name)
    if compile_ops_ is None:
        logger.warning(
            "Not found this operator ("
            + ops_name
            + ") in 'optCompilerConfig.json'. "
        )
        return d_opt_build_args

    # parser single opt
    return convert(compile_ops_)
def compile_ops(
    _md_name: str,
    fc_name: Optional[str] = None,
    gen_func: Optional[Callable[..., dict[str, Any]]] = None,
    gen_fake: Optional[Callable[..., Any]] = None,
):
    """Decorator factory binding a Python stub to a function in a JIT module.

    The decorated stub is never executed. Each call instead loads (building
    on demand) the module and forwards the arguments to the attribute named
    ``fc_name`` (or the stub's own name).

    Args:
        _md_name: Name of the module that provides the op.
        fc_name: Attribute to call inside the module; defaults to the
            decorated function's ``__name__``.
        gen_func: Optional callable invoked with the call arguments; its
            returned dict is merged into the custom build arguments
            (for example to override ``md_name``).
        gen_fake: Forwarded to ``torch_compile_guard``.

    Returns:
        A decorator. The wrapped callable returns whatever the module's op
        returns, or ``None`` when the build is skipped via ``skip_if`` or
        the loaded object is not a module.
    """

    def decorator(func):
        func.arg_checked = False

        @functools.wraps(func)
        def wrapper(*args, custom_build_args=None, **kwargs):
            # Work on a private copy: the dict is updated below, and a shared
            # default (or the caller's dict) must not carry state between calls.
            custom_build_args = dict(custom_build_args) if custom_build_args else {}
            loadName = fc_name
            md_name = _md_name
            if fc_name is None:
                loadName = func.__name__
            try:
                module = None
                if gen_func is not None:
                    custom_build_args.update(gen_func(*args, **kwargs))
                elif AITER_REBUILD and md_name not in rebuilded_list:
                    # Force one rebuild per module per process.
                    rebuilded_list.append(md_name)
                    raise ModuleNotFoundError("start rebuild")
                if module is None:
                    try:
                        module = get_module(md_name)
                    except Exception:
                        # Retry under the custom module name, if one was given.
                        md = custom_build_args.get("md_name", md_name)
                        module = get_module(md)
            except ModuleNotFoundError:
                d_args = get_args_of_build(md_name)
                d_args.update(custom_build_args)
                if d_args.get("skip_if", False):
                    logger.info(f"skip build [{md_name}] due to skip_if condition")
                    return None
                # update module name if we have a custom build
                md_name = custom_build_args.get("md_name", md_name)
                srcs = d_args["srcs"]
                flags_extra_cc = d_args["flags_extra_cc"]
                flags_extra_hip = d_args["flags_extra_hip"]
                blob_gen_cmd = d_args["blob_gen_cmd"]
                extra_include = d_args["extra_include"]
                extra_ldflags = d_args["extra_ldflags"]
                verbose = d_args["verbose"]
                is_python_module = d_args["is_python_module"]
                is_standalone = d_args["is_standalone"]
                torch_exclude = d_args["torch_exclude"]
                hipify = d_args.get("hipify", False)
                hip_clang_path = d_args.get("hip_clang_path", None)
                # Override HIP_CLANG_PATH only when the configured path exists,
                # and undo exactly that override afterwards, even when the
                # build raises. An override that was never applied must not
                # touch the caller's HIP_CLANG_PATH.
                override_hip_clang = hip_clang_path is not None and os.path.exists(
                    hip_clang_path
                )
                prev_hip_clang_path = os.environ.get("HIP_CLANG_PATH", None)
                if override_hip_clang:
                    os.environ["HIP_CLANG_PATH"] = hip_clang_path
                try:
                    build_module(
                        md_name,
                        srcs,
                        flags_extra_cc,
                        flags_extra_hip,
                        blob_gen_cmd,
                        extra_include,
                        extra_ldflags,
                        verbose,
                        is_python_module,
                        is_standalone,
                        torch_exclude,
                        hipify,
                    )
                finally:
                    if override_hip_clang:
                        if prev_hip_clang_path is not None:
                            os.environ["HIP_CLANG_PATH"] = prev_hip_clang_path
                        else:
                            os.environ.pop("HIP_CLANG_PATH", None)
                if is_python_module:
                    module = get_module(md_name)
                if md_name not in __mds:
                    __mds[md_name] = module
            if isinstance(module, types.ModuleType):
                op = getattr(module, loadName)
            else:
                return None

            def check_args():
                # Rebuild a Python signature from the first line of the op's
                # docstring and type-check the call arguments against it.
                get_asm_dir()
                import inspect
                import re
                import torch

                enum_types = ["ActivationType", "QuantType"]
                if not op.__doc__.startswith("Members:"):
                    doc_str = op.__doc__.split("\n")[0]
                    # Rewrite the docstring signature into evaluable Python:
                    # "<Name: value>" reprs, list/tuple/Sequence spellings and
                    # SupportsInt/SupportsFloat are normalised.
                    doc_str = re.sub(r"<(.*?)\:.*?>", r"\g<1>", doc_str)
                    doc_str = doc_str.replace("list[", "List[")
                    doc_str = doc_str.replace("tuple[", "Tuple[")
                    doc_str = doc_str.replace("collections.abc.Sequence[", "List[")
                    doc_str = doc_str.replace("typing.SupportsInt", "int")
                    doc_str = doc_str.replace("typing.SupportsFloat", "float")
                    # A|None --> Optional[A]
                    pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None"
                    doc_str = re.sub(pattern, r"Optional[\1]", doc_str)
                    for el in enum_types:
                        doc_str = re.sub(f" aiter.*{el} ", f" {el} ", doc_str)
                    namespace = {
                        "List": List,
                        "Optional": Optional,
                        "torch": torch,
                        "typing": typing,
                    }
                    # Define a dummy function with that signature so that
                    # inspect can parse the annotations.
                    exec(
                        f"from aiter import*\ndef {doc_str}: pass",
                        namespace,
                    )
                    foo = namespace[doc_str.split("(")[0]]
                    sig = inspect.signature(foo)
                    func.__signature__ = sig
                    ann = {k: v.annotation for k, v in sig.parameters.items()}
                    ann["return"] = sig.return_annotation
                    callargs = inspect.getcallargs(func, *args, **kwargs)
                    for el, arg in callargs.items():
                        expected_type = ann[el]
                        got_type = type(arg)
                        origin = typing.get_origin(expected_type)
                        sub_t = typing.get_args(expected_type)
                        if origin is None:
                            if not isinstance(arg, expected_type) and not (
                                # aiter_enum can be int
                                any(el in str(expected_type) for el in enum_types)
                                and isinstance(arg, int)
                            ):
                                raise TypeError(
                                    f"{loadName}: {el} needs to be {expected_type} but got {got_type}"
                                )
                        elif origin is list:
                            if (
                                not isinstance(arg, list)
                                # or not all(isinstance(i, sub_t) for i in arg)
                            ):
                                raise TypeError(
                                    f"{loadName}: {el} needs to be List[{sub_t}] but got {arg}"
                                )
                        elif origin is typing.Union or origin is types.UnionType:
                            if arg is not None and not isinstance(arg, sub_t):
                                raise TypeError(
                                    f"{loadName}: {el} needs to be Optional[{sub_t}] but got {arg}"
                                )
                        else:
                            raise TypeError(f"Unsupported type: {expected_type}")
                    func_hints = typing.get_type_hints(func)
                    if ann["return"] is None:
                        func_hints["return"] = None
                    # if ann != func_hints:
                    #     logger.warning(
                    #         f"type hints mismatch, override to --> {doc_str}"
                    #     )
                return True

            # Arguments are validated on the first successful call only.
            if not func.arg_checked:
                func.arg_checked = check_args()
            if AITER_LOG_MORE == 2:
                from ..test_common import log_args

                log_args(func, *args, **kwargs)
            return op(*args, **kwargs)

        @torch_compile_guard(device="cuda", gen_fake=gen_fake, calling_func_=func)
        def custom_wrapper(*args, **kwargs):
            return wrapper(*args, **kwargs)

        return custom_wrapper

    return decorator
{
"module_aiter_enum": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/aiter_enum_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"torch_exclude": "False",
"blob_gen_cmd": "''"
},
"module_activation": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/activation_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/activation_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": ["'-DENABLE_FP8'"],
"extra_ldflags": "None",
"extra_include": ["f'{AITER_CSRC_DIR}/include/ck_tile'"],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_custom_all_reduce": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/custom_all_reduce_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/custom_all_reduce.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_moe_sorting": {
"srcs": [
"f'{AITER_CSRC_DIR}/py_itfs_ck/moe_sorting_kernels.cu'",
"f'{AITER_CSRC_DIR}/pybind/moe_sorting_pybind.cu'",
"f'{CK_DIR}/example_hcu/ck_tile/13_moe_sorting/'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{CK_DIR}/example_hcu/ck_tile/13_moe_sorting/'"
],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_moe_sum": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_sum_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_moe": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_ck_pybind.cu'",
"f'{CK_DIR}/example_hcu/ck_tile/17_fused_moe/instances'",
"f'{CK_DIR}/example_hcu/ck_tile/17_fused_moe/moe_2stage'",
"f'{CK_DIR}/example_hcu/ck_tile/18_moe_quant/instances'",
"f'{AITER_CSRC_DIR}/py_itfs_ck/moe_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{CK_DIR}/example_hcu/ck_tile/17_fused_moe'",
"f'{CK_DIR}/example_hcu/ck_tile/18_moe_quant'"
],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
},
"module_moe_utils":{
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_utils_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/topk_softmax_kernels.cu'",
"f'{AITER_CSRC_DIR}/kernels/topk_softmax_kernels_group.cu'",
"f'{AITER_CSRC_DIR}/kernels/moe_fused_gate.cu'",
"f'{AITER_CSRC_DIR}/kernels/moe_align_sum_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": ["'-DENABLE_FP8'"],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/include/ck_tile'"
],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
},
"module_moe_asm": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_asm_2stages_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_asm/asm_fmoe_2stage.cpp'",
"f'{AITER_CSRC_DIR}/py_itfs_asm/asm_fmoe_a8.cpp'",
"f'{AITER_CSRC_DIR}/py_itfs_asm/asm_fmoe_solutions.cpp'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"hipify":"True",
"blob_gen_cmd": "''"
},
"module_awq_gemm_asm": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/awq_gemm_asm_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_asm/asm_gemm_awq.cpp'",
"f'{AITER_CSRC_DIR}/py_itfs_asm/asm_gemm_kernel_config.cpp'"
],
"flags_extra_cc": [
"f'-DAITER_OPT_KERNEL_CONFIG_PATH=\"{AITER_CSRC_DIR}/py_itfs_asm/optKernelManifest.json\"'"
],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/py_itfs_asm'"
],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
},
"module_awq_dq_asm": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/awq_dq_asm_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_asm/asm_dq_awq.cpp'"
],
"flags_extra_cc": [
"f'-DAITER_OPT_KERNEL_CONFIG_PATH=\"{AITER_CSRC_DIR}/py_itfs_asm/optKernelManifest.json\"'"
],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/py_itfs_asm'"
],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
},
"module_norm": {
"srcs": [
"f'{AITER_CSRC_DIR}/py_itfs_ck/norm_kernels.cu'",
"f'{AITER_CSRC_DIR}/pybind/norm_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{CK_DIR}/example_hcu/ck_tile/02_layernorm2d'"
],
"verbose": "False",
"blob_gen_cmd": "f'{CK_DIR}/example_hcu/ck_tile/02_layernorm2d/generate.py --api fwd --gen_blobs --working_path {{}}'"
},
"module_pos_encoding": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/pos_encoding_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/pos_encoding_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_rmsnorm": {
"srcs": [
"f'{AITER_CSRC_DIR}/kernels/rmsnorm_kernels.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_ck/rmsnorm_ck_kernels.cu'",
"f'{AITER_CSRC_DIR}/pybind/rmsnorm_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [
"f'{CK_DIR}/example_hcu/ck_tile/10_rmsnorm2d'"
],
"verbose": "False",
"blob_gen_cmd": "f'{CK_DIR}/example_hcu/ck_tile/10_rmsnorm2d/generate.py --api fwd --gen_blobs --working_path {{}}'"
},
"module_aiter_operator": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/aiter_operator_pybind.cu'",
"f'{AITER_CSRC_DIR}/include/binary_operator.cuh'",
"f'{AITER_CSRC_DIR}/kernels/binary_operator.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "f'{AITER_CSRC_DIR}/kernels/generate_binaryop.py --working_path {{}} --optype all --dtypes all'"
},
"module_aiter_unary": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/aiter_unary_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/unary_operator.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_quant": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/quant_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/quant_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [
"'-DENABLE_FP8'"
],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/include/ck_tile'"
],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_rope_general_fwd": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/rope_general_fwd_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'",
"f'{AITER_CSRC_DIR}/kernels/rope/general_fwd_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_rope_general_bwd": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/rope_general_bwd_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'",
"f'{AITER_CSRC_DIR}/kernels/rope/general_bwd_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_rope_pos_fwd": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/rope_pos_fwd_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'",
"f'{AITER_CSRC_DIR}/kernels/rope/pos_fwd_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_fused_qk_norm_mrope_cache_quant_shuffle": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/fused_qk_norm_mrope_cache_quant_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'",
"f'{AITER_CSRC_DIR}/kernels/fused_qk_norm_mrope_cache_quant.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_fused_qk_norm_rope_cache_quant_shuffle": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/fused_qk_norm_rope_cache_quant_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/rope/rope_common.h'",
"f'{AITER_CSRC_DIR}/kernels/fused_qk_norm_rope_cache_quant.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [
"'-DENABLE_FP8'"
],
"extra_ldflags": "None",
"extra_include": [
"f'{AITER_CSRC_DIR}/include/ck_tile'",
"f'{AITER_CSRC_DIR}/include/opus'"
],
"verbose": "False",
"blob_gen_cmd": "''"
},
"module_rocsolgemm": {
"srcs": [
"f'{AITER_GRADLIB_DIR}/csrc/rocsolgemm.cu'"
],
"flags_extra_cc": [
"'-O3'"
],
"flags_extra_hip": [
"'-O3'",
"'-U__CUDA_NO_HALF_OPERATORS__'",
"'-U__CUDA_NO_HALF_CONVERSIONS__'",
"'-ftemplate-depth=1024'"
],
"extra_ldflags": ["'-lrocblas'"],
"extra_include": [
"f'{AITER_GRADLIB_DIR}/include/'"
],
"hipify": "True",
"verbose": "False",
"blob_gen_cmd": "''",
"skip_if": "detect_dtk_env()"
},
"module_hipbsolgemm": {
"srcs": [
"f'{AITER_GRADLIB_DIR}/csrc/hipbsolgemm.cu'"
],
"flags_extra_cc": [
"'-O3'"
],
"flags_extra_hip": [
"'-O3'",
"'-U__CUDA_NO_HALF_OPERATORS__'",
"'-U__CUDA_NO_HALF_CONVERSIONS__'",
"'-ftemplate-depth=1024'",
"'-DENABLE_TORCH_FP8' if hasattr(torch, 'float8_e4m3fn') else '' "
],
"extra_ldflags": ["'-lhipblaslt'"],
"extra_include": [
"f'{AITER_GRADLIB_DIR}/include/'"
],
"hipify": "True",
"verbose": "False",
"blob_gen_cmd": "''",
"skip_if": "detect_dtk_env()"
},
"module_moe_c_kernel": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/moe_c_pybind.cu'",
"f'{MOE_C_DIR}/csrc_for_aiter'",
"f'{AITER_CSRC_DIR}/py_itfs_moe_c/moe_c.cu'"
],
"flags_extra_cc": ["' -mllvm -support-768-vgprs=true -mllvm -disable-machine-sink '"
],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
},
"module_topk_plain": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/topk_plain_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/topk_plain_kernels.cu'",
"f'{AITER_CSRC_DIR}/kernels/topk_per_row_kernels.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
},
"module_topk_transform": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/topk_transform_pybind.cu'",
"f'{AITER_CSRC_DIR}/kernels/topk_transform.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"hipify": "True",
"blob_gen_cmd": "''"
}
}
# SPDX-License-Identifier: MIT
\ No newline at end of file
# SPDX-License-Identifier: MIT
# mypy: allow-untyped-defs
import collections
# One record per extension name: its current version number and the hash
# that version was computed from.
Entry = collections.namedtuple("Entry", "version, hash")
def update_hash(seed, value):
    """Fold ``hash(value)`` into ``seed`` and return the combined hash.

    Uses the boost::hash_combine formula:
    https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
    """
    boost_constant = 0x9E3779B9
    mixed = hash(value) + boost_constant + (seed << 6) + (seed >> 2)
    return seed ^ mixed
def hash_source_files(hash_value, source_files):
    """Fold the raw bytes of every file in ``source_files`` into ``hash_value``.

    Files are read in binary mode, in the order given. With no files the
    input hash is returned unchanged.
    """
    digest = hash_value
    for path in source_files:
        with open(path, "rb") as handle:
            contents = handle.read()
        digest = update_hash(digest, contents)
    return digest
def hash_build_arguments(hash_value, build_arguments):
    """Fold every argument of every non-empty group into ``hash_value``.

    ``build_arguments`` is an iterable of groups; falsy groups (``None``,
    empty lists, ...) contribute nothing.
    """
    digest = hash_value
    for group in build_arguments:
        for argument in group or ():
            digest = update_hash(digest, argument)
    return digest
class ExtensionVersioner:
    """Track a version counter per extension, bumped when its build inputs change."""

    def __init__(self):
        # name -> Entry(version, hash)
        self.entries = {}

    def get_version(self, name):
        """Return the recorded version for ``name``, or None if it is unknown."""
        known = self.entries.get(name)
        if known is None:
            return None
        return known.version

    def bump_version_if_changed(
        self,
        name,
        source_files,
        build_arguments,
        build_directory,
        with_cuda,
        is_python_module,
        is_standalone,
    ):
        """Hash the build inputs and return the resulting version for ``name``.

        A first-time ``name`` is recorded at version 0. A known name keeps
        its version while the hash is unchanged and moves to the next
        version (with the new hash stored) when it differs.
        """
        digest = hash_source_files(0, source_files)
        digest = hash_build_arguments(digest, build_arguments)
        for component in (build_directory, with_cuda, is_python_module, is_standalone):
            digest = update_hash(digest, component)

        previous = self.entries.get(name)
        if previous is None:
            current = Entry(0, digest)
        elif previous.hash != digest:
            current = Entry(previous.version + 1, digest)
        else:
            # Nothing changed: keep the stored entry as is.
            return previous.version
        self.entries[name] = current
        return current.version
# SPDX-License-Identifier: MIT
import os
import functools
import subprocess
@functools.lru_cache(maxsize=1)
def get_gfx():
    """Return the GPU architecture string to build for.

    The ``GPU_ARCHS`` environment variable wins when it is set to anything
    other than ``"native"``. Otherwise ``rocminfo`` is executed and the text
    after the last ``:`` on the first output line containing "gfx" is
    returned. If no such line exists, ``"native"`` itself is returned.

    Raises:
        RuntimeError: When running or parsing ``rocminfo`` fails; the
            original exception is attached as the cause.
    """
    gfx = os.getenv("GPU_ARCHS", "native")
    if gfx != "native":
        return gfx
    try:
        result = subprocess.run(
            ["rocminfo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        for line in result.stdout.split("\n"):
            if "gfx" in line.lower():
                return line.split(":")[-1].strip()
    except Exception as e:
        raise RuntimeError(f"Get GPU arch from rocminfo failed {str(e)}") from e
    return gfx
@functools.lru_cache(maxsize=1)
def get_cu_num():
    """Return ``multi_processor_count`` of the current torch CUDA device (cached)."""
    import torch

    properties = torch.cuda.get_device_properties(torch.cuda.current_device())
    return properties.multi_processor_count
# SPDX-License-Identifier: MIT
# This file origins from pytorch:
# https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py
# We make slight changes to enable ninja response file
# mypy: allow-untyped-defs
import copy
import importlib
import importlib.abc
import os
import re
import shlex
import shutil
import subprocess
import sys
import sysconfig
import warnings
from typing import Dict, List, Optional, Tuple, Union
import setuptools
from _cpp_extension_versioner import ExtensionVersioner
from file_baton import FileBaton
from hipify import hipify_python
from hipify.hipify_python import GeneratedFileCleaner
from packaging.version import Version
from setuptools.command.build_ext import build_ext
# Platform detection derived from sys.platform.
IS_WINDOWS = sys.platform == "win32"
IS_LINUX = sys.platform.startswith("linux")
# File-name conventions and linker flag for build outputs; the values here
# are the ".so"/"lib"/"-shared" forms.
LIB_EXT = ".so"
EXEC_EXT = ""
CLIB_PREFIX = "lib"
CLIB_EXT = ".so"
SHARED_FLAG = "-shared"
SUBPROCESS_DECODE_ARGS = ()
# Minimum compiler versions as (major, minor, patch) tuples.
MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)
# Type aliases: a (min, max) pair of version tuples, and a mapping from a
# string key to such a pair.
VersionRange = Tuple[Tuple[int, ...], Tuple[int, ...]]
VersionMap = Dict[str, VersionRange]
# The following values were taken from the following GitHub gist that
# summarizes the minimum valid major versions of g++/clang++ for each supported
# CUDA version: https://gist.github.com/ax3l/9489132
# Or from include/crt/host_config.h in the CUDA SDK
# The second value is the exclusive(!) upper bound, i.e. min <= version < max
# NOTE(review): only the clang minimum follows here; the per-CUDA-version
# ranges this comment describes are not defined in this part of the file.
MINIMUM_CLANG_VERSION = (3, 3, 0)
# Names exported by `from <module> import *`.
__all__ = [
    "check_compiler_ok_for_platform",
    "get_compiler_abi_compatibility_and_version",
    "BuildExtension",
    "CppExtension",
    "CUDAExtension",
    "include_paths",
    "library_paths",
    "load",
    "is_ninja_available",
    "verify_ninja_availability",
    "get_cxx_compiler",
    "check_compiler_is_gcc",
]
def executable_path(executable: str) -> str:
    """
    Locate ``executable`` and return its resolved path.

    Candidates are tried in this order, and the first one that exists wins:
      1. the environment variable named after the upper-cased executable
         (for example ``HIPCC`` for ``hipcc``),
      2. the executable found on PATH,
      3. ``bin/``, ``hip/bin/`` and ``llvm/bin/`` under the ROCm home, when a
         ROCm home is known.

    Args:
        executable (str): The name of the executable.

    Returns:
        The real path (symlinks resolved) of the executable.

    Raises:
        AssertionError: When no candidate exists.
    """
    candidates = [os.environ.get(executable.upper()), shutil.which(executable)]
    home = _find_rocm_home()
    if home:
        for subdir in (("bin",), ("hip", "bin"), ("llvm", "bin")):
            candidates.append(os.path.join(home, *subdir, executable))

    path = None
    for candidate in candidates:
        if candidate and os.path.exists(candidate):
            path = os.path.realpath(candidate)
            break
    assert path is not None, (
        f"Could not find {executable} in PATH or ROCM_HOME({home})"
    )
    return os.path.realpath(path)
def get_hip_version():
    """Return the raw text printed by ``hipconfig --version``.

    Raises:
        RuntimeError: When ``hipconfig`` cannot be located or executed.
    """
    try:
        return subprocess.check_output(
            [executable_path("hipconfig"), "--version"], text=True
        )
    except Exception:
        raise RuntimeError("ROCm version file not found")
def _find_rocm_home() -> Optional[str]:
    """Find the ROCm install path.

    Lookup order: the ``ROCM_HOME`` then ``ROCM_PATH`` environment
    variables; the install root implied by the ``hipcc`` found on PATH;
    finally ``/opt/rocm`` if it exists and no ``hipcc`` is on PATH. A
    message is printed to stderr when nothing is found.

    Returns:
        The install root, or None when it could not be determined.
    """
    # Guess #1
    rocm_home = os.environ.get("ROCM_HOME") or os.environ.get("ROCM_PATH")
    if rocm_home is not None:
        return rocm_home

    hipcc_path = shutil.which("hipcc")
    if hipcc_path is None:
        # Guess #3
        if os.path.exists("/opt/rocm"):
            rocm_home = "/opt/rocm"
    else:
        # Guess #2
        bin_dir = os.path.dirname(os.path.realpath(hipcc_path))
        rocm_home = os.path.dirname(bin_dir)
        # can be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
        if os.path.basename(rocm_home) == "hip":
            rocm_home = os.path.dirname(rocm_home)

    if rocm_home is None:
        print(
            f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'", file=sys.stderr
        )
    return rocm_home
def _join_rocm_home(*paths) -> str:
    """
    Return ``paths`` joined onto ``ROCM_HOME``.

    The check is deferred to call time so that a missing ROCm install only
    becomes an error once a ROCm-specific path is actually needed.

    Raises:
        OSError: When ``ROCM_HOME`` is None, or when running on Windows.
    """
    if ROCM_HOME is None:
        raise OSError(
            "ROCM_HOME environment variable is not set. "
            "Please set it to your ROCm install root."
        )
    if IS_WINDOWS:
        raise OSError(
            "Building PyTorch extensions using ROCm and Windows is not supported."
        )
    return os.path.join(ROCM_HOME, *paths)
# Warning template for a possibly ABI-incompatible compiler; the single
# positional `{}` placeholder receives the compiler name.
ABI_INCOMPATIBILITY_WARNING = """
!! WARNING !!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.
See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 5 or higher.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! WARNING !!
"""
# Warning template for a compiler that differs from the one PyTorch was built
# with. Placeholders: {user_compiler}, {pytorch_compiler}, {platform}.
WRONG_COMPILER_WARNING = """
!! WARNING !!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({user_compiler}) is not compatible with the compiler Pytorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.
See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!! WARNING !!
"""
# Module-level ROCm/HIP discovery, evaluated once at import time.
# HIP_VERSION is whatever get_hip_version() returns; ``None`` is handled below.
HIP_VERSION = get_hip_version()
ROCM_HOME = _find_rocm_home()
# HIP_HOME is only derived when a ROCm root was found.
HIP_HOME = None if not ROCM_HOME else _join_rocm_home("hip")
# HIP extensions require both an install root and a detected HIP version.
IS_HIP_EXTENSION = (ROCM_HOME is not None) and (HIP_VERSION is not None)
# (major, minor) as integers, taken from the first two dot-separated fields
# of HIP_VERSION; stays None when no HIP version was detected.
if HIP_VERSION is None:
    ROCM_VERSION = None
else:
    ROCM_VERSION = tuple(int(field) for field in HIP_VERSION.split(".")[:2])
# PyTorch releases have the version pattern major.minor.patch, whereas when
# PyTorch is built from source, we append the git commit hash, which gives
# it the below pattern.
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r"\d+\.\d+\.\d+\w+\+\w+")
# Flag list named for MSVC builds.
# NOTE(review): not referenced by the unix-only compile wrappers in
# BuildExtension.build_extensions -- confirm whether it is still used.
COMMON_MSVC_FLAGS = [
    "/MD",
    "/wd4819",
    "/wd4251",
    "/wd4244",
    "/wd4267",
    "/wd4275",
    "/wd4018",
    "/wd4190",
    "/wd4624",
    "/wd4067",
    "/wd4068",
    "/EHsc",
]
# Diagnostic names; presumably cudafe warnings to suppress on MSVC -- TODO confirm usage.
MSVC_IGNORE_CUDAFE_WARNINGS = [
    "base_class_has_different_dll_interface",
    "field_without_dll_interface",
    "dll_interface_conflict_none_assumed",
    "dll_interface_conflict_dllexport_assumed",
]
# Flag list named for nvcc.
# NOTE(review): the compile wrappers in this section only use the HIP lists below.
COMMON_NVCC_FLAGS = [
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
    "--expt-relaxed-constexpr",
]
# Prepended to the compile flags of every source (host C++ and hipcc) by the
# BuildExtension compile wrappers when IS_HIP_EXTENSION is true.
COMMON_HIP_FLAGS = [
    "-fPIC",
    "-D__HIP_PLATFORM_AMD__=1",
    "-DUSE_ROCM=1",
    "-DHIPBLAS_V2",
]
# Prepended only to the flags of sources that are compiled with hipcc.
COMMON_HIPCC_FLAGS = [
    "-DCUDA_HAS_FP16=1",
    "-D__HIP_NO_HALF_OPERATORS__=1",
    "-D__HIP_NO_HALF_CONVERSIONS__=1",
    "-mcmodel=large",
    "-fno-unique-section-names",
    "-ffunction-sections",
    "-fdata-sections",
]
# Symbols are hidden by default; setting AITER_SYMBOL_VISIBLE to a non-zero
# integer keeps default visibility. A non-integer value makes int() raise
# ValueError at import time.
if not int(os.environ.get("AITER_SYMBOL_VISIBLE", "0")):
    COMMON_HIPCC_FLAGS.extend(["-fvisibility=hidden", "-fvisibility-inlines-hidden"])
# Shared versioner used by _jit_compile to detect changed build inputs per extension name.
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
# Platform-name -> vcvars architecture argument (Windows only; presumably for
# locating the MSVC environment -- TODO confirm usage).
PLAT_TO_VCVARS = {
    "win32": "x86",
    "win-amd64": "x86_amd64",
}
def get_cxx_compiler():
    """Return the C++ compiler named by the ``CXX`` environment variable, or ``"c++"`` if unset."""
    try:
        return os.environ["CXX"]
    except KeyError:
        return "c++"
def _is_binary_build() -> bool:
    """Return ``True`` when the installed torch version string does not match the built-from-source pattern."""
    import torch

    source_build_match = BUILT_FROM_SOURCE_VERSION_PATTERN.match(
        torch.version.__version__
    )
    return source_build_match is None
def _accepted_compilers_for_platform() -> List[str]:
    """Return the compiler names accepted by ``check_compiler_ok_for_platform``.

    The first entry is the one quoted as the expected compiler in warnings.
    """
    gnu_names = ["g++", "gcc"]
    # gnu-c++ and gnu-cc are the conda gcc compilers
    conda_gnu_names = ["gnu-c++", "gnu-cc"]
    clang_names = ["clang++", "clang"]
    return gnu_names + conda_gnu_names + clang_names
def _maybe_write(filename, new_content):
    r"""
    Write ``new_content`` to ``filename`` unless the file already holds exactly
    that text.

    Leaving an up-to-date file untouched avoids triggering a recompile.
    """
    already_current = False
    if os.path.exists(filename):
        with open(filename) as existing:
            already_current = existing.read() == new_content
    if already_current:
        # The file already contains the right thing!
        return
    with open(filename, "w") as out_file:
        out_file.write(new_content)
def check_compiler_ok_for_platform(compiler: str) -> bool:
    """
    Verify that the compiler is one of the accepted ones for the current platform.

    Args:
        compiler (str): The compiler executable to check.

    Returns:
        ``False`` if the executable cannot be found on ``PATH``.
        ``True`` if its resolved path contains one of the names from
        ``_accepted_compilers_for_platform()``. Otherwise, on Linux, the
        compiler is invoked with ``-v`` to see through wrappers (e.g. sccache):
        the result is ``True`` for a gcc-backed ``c++`` wrapper, for a
        ``COLLECT_GCC`` path containing an accepted name, or for output that
        mentions ``clang version``. On other platforms the result is ``False``.

    Raises:
        subprocess.CalledProcessError / OSError: if invoking ``compiler -v`` fails.
    """
    located = shutil.which(compiler)
    if located is None:
        # shutil.which returns None for an unknown executable; resolving that
        # with os.path.realpath would raise TypeError instead of reporting
        # "not OK", so bail out before touching the path.
        return False
    # Resolve symlinks, in particular from 'c++' to e.g. 'g++'.
    compiler_path = os.path.realpath(located)
    # Check the compiler name
    if any(name in compiler_path for name in _accepted_compilers_for_platform()):
        return True
    # If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag
    env = os.environ.copy()
    env["LC_ALL"] = "C"  # Don't localize output
    version_string = subprocess.check_output(
        [compiler, "-v"], stderr=subprocess.STDOUT, env=env
    ).decode(*SUBPROCESS_DECODE_ARGS)
    if IS_LINUX:
        # Check for 'gcc' or 'g++' for sccache wrapper
        pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
        results = re.findall(pattern, version_string)
        if len(results) != 1:
            # Clang is also a supported compiler on Linux
            # Though on Ubuntu it's sometimes called "Ubuntu clang version"
            return "clang version" in version_string
        compiler_path = os.path.realpath(results[0].strip())
        # On RHEL/CentOS c++ is a gcc compiler wrapper
        if os.path.basename(compiler_path) == "c++" and "gcc version" in version_string:
            return True
        return any(name in compiler_path for name in _accepted_compilers_for_platform())
    return False
def get_compiler_abi_compatibility_and_version(
    compiler, torch_exclude=False
) -> Tuple[bool, Version]:
    """
    Determine if the given compiler is ABI-compatible with PyTorch alongside its version.

    Args:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.
        torch_exclude (bool): When true, the "is torch a binary build" shortcut
            is skipped (so torch is not imported by this check).

    Returns:
        A tuple ``(ok, version)``. ``ok`` is ``True`` when the compiler is
        considered (likely) ABI-compatible or when the check is skipped, and
        ``False`` when the compiler is rejected or its version cannot be
        determined. ``version`` is the detected compiler version, or
        ``Version("0.0.0")`` when the check was skipped or failed.
    """
    if not torch_exclude:
        # A torch built from source is assumed to match the local toolchain.
        if not _is_binary_build():
            return (True, Version("0.0.0"))
    # Explicit opt-out of the ABI check via environment variable.
    if os.environ.get("TORCH_DONT_CHECK_COMPILER_ABI") in [
        "ON",
        "1",
        "YES",
        "TRUE",
        "Y",
    ]:
        return (True, Version("0.0.0"))
    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(
            WRONG_COMPILER_WARNING.format(
                user_compiler=compiler,
                pytorch_compiler=_accepted_compilers_for_platform()[0],
                platform=sys.platform,
            )
        )
        return (False, Version("0.0.0"))
    try:
        # NOTE(review): only the Linux branch assigns ``version`` and
        # ``minimum_required_version``; other platforms are not handled here.
        if IS_LINUX:
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output(
                [compiler, "-dumpfullversion", "-dumpversion"]
            )
            # Take the first "X.Y.Z" found in the output; fall back to 0.0.0.
            match = re.search(
                r"(\d+)\.(\d+)\.(\d+)",
                versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip(),
            )
            version = ["0", "0", "0"] if match is None else list(match.groups())
    except Exception:
        _, error, _ = sys.exc_info()
        warnings.warn(f"Error checking compiler version for {compiler}: {error}")
        return (False, Version("0.0.0"))
    if tuple(map(int, version)) >= minimum_required_version:
        return (True, Version(".".join(version)))
    # Too old: report "<compiler> <version>" in the warning text.
    compiler = f'{compiler} {".".join(version)}'
    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))
    return (False, Version(".".join(version)))
class BuildExtension(build_ext):
    """
    A custom :mod:`setuptools` build extension .

    This :class:`setuptools.build_ext` subclass takes care of passing the
    minimum required compiler flags (e.g. ``-std=c++20``) as well as mixed
    C++/CUDA compilation (and support for CUDA files in general). In this
    implementation, sources detected by ``_is_cuda_file`` are compiled with
    ``hipcc``.

    When using :class:`BuildExtension`, it is allowed to supply a dictionary
    for ``extra_compile_args`` (rather than the usual list) that maps from
    languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
    supply to the compiler. This makes it possible to supply different flags to
    the C++ and CUDA compiler during mixed compilation. The ``nvcc`` entry is
    the one applied to hipcc-compiled sources.

    ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
    attempt to build using the Ninja backend. Ninja greatly speeds up
    compilation compared to the standard ``setuptools.build_ext``.
    Fallbacks to the standard distutils backend if Ninja is not available.

    .. note::
        By default, the Ninja backend uses #CPUS + 2 workers to build the
        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the `MAX_JOBS` environment
        variable to a non-negative number.
    """

    @classmethod
    def with_options(cls, **options):
        """Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options."""

        class cls_with_options(cls):  # type: ignore[misc, valid-type]
            def __init__(self, *args, **kwargs):
                # ``options`` override any same-named keyword supplied by the caller.
                kwargs.update(options)
                super().__init__(*args, **kwargs)

        return cls_with_options

    def __init__(self, *args, **kwargs) -> None:
        """Read ``no_python_abi_suffix`` and ``use_ninja`` from ``kwargs``; disable ninja (with a warning) when it is unavailable."""
        super().__init__(*args, **kwargs)
        self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
        self.use_ninja = kwargs.get("use_ninja", True)
        if self.use_ninja:
            # Test if we can use ninja. Fallback otherwise.
            msg = (
                "Attempted to use ninja as the BuildExtension backend but "
                "{}. Falling back to using the slow distutils backend."
            )
            if not is_ninja_available():
                warnings.warn(msg.format("we could not find ninja."))
                self.use_ninja = False

    def finalize_options(self) -> None:
        """Finalize options; with ninja enabled, set ``self.force`` so every build is handed to ninja."""
        super().finalize_options()
        if self.use_ninja:
            self.force = True

    def build_extensions(self) -> None:
        """Add the torch-specific compile flags to every extension, patch the compiler's compile hook, then run the stock build."""
        import torch

        # Scan for the first ".cu" source.
        # NOTE(review): ``cuda_ext`` is not read again in this method.
        cuda_ext = False
        extension_iter = iter(self.extensions)
        extension = next(extension_iter, None)
        while not cuda_ext and extension:
            for source in extension.sources:
                _, ext = os.path.splitext(source)
                if ext == ".cu":
                    cuda_ext = True
                    break
            extension = next(extension_iter, None)
        for extension in self.extensions:
            # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
            # extra_compile_args is a dict. Otherwise, default torch flags do
            # not get passed. Necessary when only one of 'cxx' and 'nvcc' is
            # passed to extra_compile_args in CUDAExtension, i.e.
            # CUDAExtension(..., extra_compile_args={'cxx': [...]})
            # or
            # CUDAExtension(..., extra_compile_args={'nvcc': [...]})
            if isinstance(extension.extra_compile_args, dict):
                for ext in ["cxx", "nvcc"]:
                    if ext not in extension.extra_compile_args:
                        extension.extra_compile_args[ext] = []
            self._add_compile_flag(extension, "-DTORCH_API_INCLUDE_EXTENSION_H")
            # See note [Pybind11 ABI constants]
            for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
                val = getattr(torch._C, f"_PYBIND11_{name}")
                if val is not None:
                    self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"')
            self._define_torch_extension_name(extension)
            self._add_gnu_cpp_abi_flag(extension)
            if "nvcc_dlink" in extension.extra_compile_args:
                assert (
                    self.use_ninja
                ), f"With dlink=True, ninja is required to build cuda extension {extension.name}."
        # Register .cu, .cuh, .hip, and .mm as valid source extensions.
        self.compiler.src_extensions += [".cu", ".cuh", ".hip"]
        if torch.backends.mps.is_built():
            self.compiler.src_extensions += [".mm"]
        # Save the original _compile method for later.
        if self.compiler.compiler_type == "msvc":
            self.compiler._cpp_extensions += [".cu", ".cuh"]
            original_compile = self.compiler.compile
            original_spawn = self.compiler.spawn
        else:
            original_compile = self.compiler._compile

        def append_std17_if_no_std_present(cflags) -> None:
            # NOTE(review): despite the name, the default appended below is
            # c++20, and despite the ``-> None`` annotation the list is also
            # returned. ``cflags`` is modified in place.
            # NVCC does not allow multiple -std to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            cpp_format_prefix = (
                "/{}:" if self.compiler.compiler_type == "msvc" else "-{}="
            )
            cpp_flag_prefix = cpp_format_prefix.format("std")
            cpp_flag = cpp_flag_prefix + "c++20"
            if not any(flag.startswith(cpp_flag_prefix) for flag in cflags):
                cflags.append(cpp_flag)
            # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            # NOTE(review): this runs for host C++ flags as well as hipcc
            # flags whenever the CC environment variable is set -- confirm
            # that passing -ccbin to those compilers is intended.
            _ccbin = os.getenv("CC")
            if _ccbin is not None and not any(
                flag.startswith(("-ccbin", "--compiler-bindir")) for flag in cflags
            ):
                cflags.extend(["-ccbin", _ccbin])
            return cflags

        def convert_to_absolute_paths_inplace(paths):
            # Helper function. See Note [Absolute include_dirs]
            if paths is not None:
                for i in range(len(paths)):
                    if not os.path.isabs(paths[i]):
                        paths[i] = os.path.abspath(paths[i])

        def unix_wrap_single_compile(
            obj, src, ext, cc_args, extra_postargs, pp_opts
        ) -> None:
            # Replacement for ``self.compiler._compile`` (non-ninja path):
            # selects hipcc and the matching flag set per source file.
            # Copy before we make any modifications.
            cflags = copy.deepcopy(extra_postargs)
            try:
                original_compiler = self.compiler.compiler_so
                if _is_cuda_file(src):
                    nvcc = [executable_path("hipcc")]
                    self.compiler.set_executable("compiler_so", nvcc)
                    if isinstance(cflags, dict):
                        cflags = cflags["nvcc"]
                    cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags)
                elif isinstance(cflags, dict):
                    cflags = cflags["cxx"]
                if IS_HIP_EXTENSION:
                    cflags = COMMON_HIP_FLAGS + cflags
                append_std17_if_no_std_present(cflags)
                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                # Put the original compiler back in place.
                self.compiler.set_executable("compiler_so", original_compiler)

        def unix_wrap_ninja_compile(
            sources,
            output_dir=None,
            macros=None,
            include_dirs=None,
            debug=0,
            extra_preargs=None,
            extra_postargs=None,
            depends=None,
        ):
            r"""Compiles sources by outputting a ninja file and running it."""
            # NB: I copied some lines from self.compiler (which is an instance
            # of distutils.UnixCCompiler). See the following link.
            # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567
            # This can be fragile, but a lot of other repos also do this
            # (see https://github.com/search?q=_setup_compile&type=Code)
            # so it is probably OK; we'll also get CI signal if/when
            # we update our python version (which is when distutils can be
            # upgraded)
            # Use absolute path for output_dir so that the object file paths
            # (`objects`) get generated with absolute paths.
            output_dir = os.path.abspath(output_dir)
            # See Note [Absolute include_dirs]
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)
            _, objects, extra_postargs, pp_opts, _ = self.compiler._setup_compile(
                output_dir, macros, include_dirs, sources, depends, extra_postargs
            )
            common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs)
            # Everything after the executable name in the configured compile command.
            extra_cc_cflags = self.compiler.compiler_so[1:]
            with_cuda = any(map(_is_cuda_file, sources))
            # extra_postargs can be either:
            # - a dict mapping cxx/nvcc to extra flags
            # - a list of extra flags.
            if isinstance(extra_postargs, dict):
                post_cflags = extra_postargs["cxx"]
            else:
                post_cflags = list(extra_postargs)
            if IS_HIP_EXTENSION:
                post_cflags = COMMON_HIP_FLAGS + post_cflags
            append_std17_if_no_std_present(post_cflags)
            cuda_post_cflags = None
            cuda_cflags = None
            if with_cuda:
                cuda_cflags = common_cflags
                if isinstance(extra_postargs, dict):
                    cuda_post_cflags = extra_postargs["nvcc"]
                else:
                    cuda_post_cflags = list(extra_postargs)
                cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(
                    cuda_post_cflags
                )
                cuda_post_cflags = (
                    COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags
                )
                append_std17_if_no_std_present(cuda_post_cflags)
                # Flags are shell-quoted before being handed to the ninja writer.
                cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
                cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]
            _write_ninja_file_and_compile_objects(
                sources=sources,
                objects=objects,
                cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags],
                post_cflags=[shlex.quote(f) for f in post_cflags],
                cuda_cflags=cuda_cflags,
                cuda_post_cflags=cuda_post_cflags,
                cuda_dlink_post_cflags=None,
                build_directory=output_dir,
                verbose=True,
                with_cuda=with_cuda,
            )
            # Return *all* object filenames, not just the ones we just built.
            return objects

        # Monkey-patch the _compile or compile method.
        # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
        if self.compiler.compiler_type == "msvc":
            # No wrapper is installed for MSVC; the stock compile path is used.
            print("currently only support unix")
            # if self.use_ninja:
            #     self.compiler.compile = win_wrap_ninja_compile
            # else:
            #     self.compiler.compile = win_wrap_single_compile
        else:
            if self.use_ninja:
                self.compiler.compile = unix_wrap_ninja_compile
            else:
                self.compiler._compile = unix_wrap_single_compile
        build_ext.build_extensions(self)

    def get_ext_filename(self, ext_name):
        """Return the output filename for ``ext_name``, dropping the ABI tag when ``no_python_abi_suffix`` is set."""
        # Get the original shared library name. For Python 3, this name will be
        # suffixed with "<SOABI>.so", where <SOABI> will be something like
        # cpython-37m-x86_64-linux-gnu.
        ext_filename = super().get_ext_filename(ext_name)
        # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
        # component. This makes building shared libraries with setuptools that
        # aren't Python modules nicer.
        if self.no_python_abi_suffix:
            # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
            ext_filename_parts = ext_filename.split(".")
            # Omit the second to last element.
            without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
            ext_filename = ".".join(without_abi)
        return ext_filename

    def _add_compile_flag(self, extension, flag):
        """Append ``flag`` to the extension's compile args (to every per-language list when the args are a dict)."""
        # Work on a deep copy so lists shared with the caller are not modified.
        extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
        if isinstance(extension.extra_compile_args, dict):
            for args in extension.extra_compile_args.values():
                args.append(flag)
        else:
            extension.extra_compile_args.append(flag)

    def _define_torch_extension_name(self, extension):
        """Add ``-DTORCH_EXTENSION_NAME=<last dotted component of extension.name>``."""
        # pybind11 doesn't support dots in the names
        # so in order to support extensions in the packages
        # like torch._C, we take the last part of the string
        # as the library name
        names = extension.name.split(".")
        name = names[-1]
        define = f"-DTORCH_EXTENSION_NAME={name}"
        self._add_compile_flag(extension, define)

    def _add_gnu_cpp_abi_flag(self, extension):
        """Add the ``-D_GLIBCXX_USE_CXX11_ABI`` define matching the installed torch."""
        import torch

        # use the same CXX ABI as what PyTorch was compiled with
        self._add_compile_flag(
            extension,
            "-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)),
        )
def CppExtension(name, sources, *args, **kwargs):
    """
    Create a :class:`setuptools.Extension` for C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a C++ extension.
    The torch include/library directories and the ``c10``, ``torch``,
    ``torch_cpu`` and ``torch_python`` libraries are appended to whatever the
    caller supplied, and ``language`` is forced to ``"c++"``.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor. Full list arguments can be found at
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

    The lists passed as ``include_dirs``, ``library_dirs`` and ``libraries``
    are copied, never modified in place, so the same list object can safely
    be reused for several extensions.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
        >>> setup(
        ...     name='extension',
        ...     ext_modules=[
        ...         CppExtension(
        ...             name='extension',
        ...             sources=['extension.cpp'],
        ...             extra_compile_args=['-g'],
        ...             extra_link_args=['-Wl,--no-as-needed', '-lm'])
        ...     ],
        ...     cmdclass={
        ...         'build_ext': BuildExtension
        ...     })
    """
    # Copy the caller's lists before extending them: the previous in-place
    # ``+=`` / ``append`` leaked the torch entries back into the caller's
    # objects and accumulated duplicates when a list was reused.
    include_dirs = list(kwargs.get("include_dirs", []))
    include_dirs += include_paths()
    kwargs["include_dirs"] = include_dirs

    library_dirs = list(kwargs.get("library_dirs", []))
    library_dirs += library_paths()
    kwargs["library_dirs"] = library_dirs

    libraries = list(kwargs.get("libraries", []))
    libraries += ["c10", "torch", "torch_cpu", "torch_python"]
    kwargs["libraries"] = libraries

    kwargs["language"] = "c++"
    return setuptools.Extension(name, sources, *args, **kwargs)
def CUDAExtension(name, sources, *args, **kwargs):
    """
    Create a :class:`setuptools.Extension` for CUDA/C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a CUDA/C++
    extension. This includes the CUDA include path, library path and runtime
    library. When ``IS_HIP_EXTENSION`` is true the sources are first passed
    through ``hipify`` and the HIP runtime libraries (``amdhip64``,
    ``c10_hip``, ``torch_hip``) are linked instead of the CUDA ones.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor. Full list arguments can be found at
    https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

    The lists passed as ``include_dirs``, ``library_dirs`` and ``libraries``,
    as well as the ``extra_compile_args`` dict and its ``nvcc_dlink`` list,
    are copied rather than modified in place.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
        >>> setup(
        ...     name='cuda_extension',
        ...     ext_modules=[
        ...         CUDAExtension(
        ...                 name='cuda_extension',
        ...                 sources=['extension.cpp', 'extension_kernel.cu'],
        ...                 extra_compile_args={'cxx': ['-g'],
        ...                                     'nvcc': ['-O2']},
        ...                 extra_link_args=['-Wl,--no-as-needed', '-lcuda'])
        ...     ],
        ...     cmdclass={
        ...         'build_ext': BuildExtension
        ...     })

    Compute capabilities:

    By default the extension will be compiled to run on all archs of the cards visible during the
    building process of the extension, plus PTX. If down the road a new card is installed the
    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
    newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch
    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
    support (see below for details on PTX).

    You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
    CCs you want the extension to support:

    ``TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py``
    ``TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py``

    The +PTX option causes extension kernel binaries to include PTX instructions for the specified
    CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
    the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
    CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
    provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
    those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better
    off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6,
    "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but
    "8.0 8.6" would be better.

    Note that while it's possible to include all supported archs, the more archs get included the
    slower the building process will be, as it will build a separate kernel image for each arch.

    Note that CUDA-11.5 nvcc will hit internal compiler error while parsing torch/extension.h on Windows.
    To workaround the issue, move python binding logic to pure C++ file.

    Example use:
        #include <ATen/ATen.h>
        at::Tensor SigmoidAlphaBlendForwardCuda(....)

    Instead of:
        #include <torch/extension.h>
        torch::Tensor SigmoidAlphaBlendForwardCuda(...)

    Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
    Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48

    Relocatable device code linking:

    If you want to reference device symbols across compilation units (across object files),
    the object files need to be built with `relocatable device code` (-rdc=true or -dc).
    An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore.
    `Relocatable device code` is less optimized so it needs to be used only on object files that need it.
    Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
    helps reduce the potential perf degradation of `-rdc`.
    Note that it needs to be used at both steps to be useful.

    If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
    There is also a case where `-dlink` is used without `-rdc`:
    when an extension is linked against a static lib containing rdc-compiled objects
    like the [SHMEM library].

    Note: Ninja is required to build a CUDA Extension with RDC linking.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
        >>> CUDAExtension(
        ...        name='cuda_extension',
        ...        sources=['extension.cpp', 'extension_kernel.cu'],
        ...        dlink=True,
        ...        dlink_libraries=["dlink_lib"],
        ...        extra_compile_args={'cxx': ['-g'],
        ...                            'nvcc': ['-O2', '-rdc=true']})
    """
    # Copy every caller-supplied container before extending it; the previous
    # in-place ``+=`` / ``append`` modified the caller's own objects.
    library_dirs = list(kwargs.get("library_dirs", []))
    library_dirs += library_paths(cuda=True)
    kwargs["library_dirs"] = library_dirs

    libraries = list(kwargs.get("libraries", []))
    libraries += ["c10", "torch", "torch_cpu", "torch_python"]
    if IS_HIP_EXTENSION:
        libraries += ["amdhip64", "c10_hip", "torch_hip"]
    else:
        libraries += ["cudart", "c10_cuda", "torch_cuda"]
    kwargs["libraries"] = libraries

    include_dirs = list(kwargs.get("include_dirs", []))

    if IS_HIP_EXTENSION:
        build_dir = os.getcwd()
        hipify_result = hipify_python.hipify(
            project_directory=build_dir,
            output_directory=build_dir,
            header_include_dirs=include_dirs,
            includes=[os.path.join(build_dir, "*")],  # limit scope to build_dir only
            extra_files=[os.path.abspath(s) for s in sources],
            show_detailed=True,
            is_pytorch_extension=True,
            hipify_extra_files_only=True,  # don't hipify everything in includes path
        )
        hipified_sources = set()
        for source in sources:
            s_abs = os.path.abspath(source)
            # Fall back to the original file when hipify produced no output for it.
            hipified_s_abs = (
                hipify_result[s_abs].hipified_path
                if (
                    s_abs in hipify_result
                    and hipify_result[s_abs].hipified_path is not None
                )
                else s_abs
            )
            # setup() arguments must *always* be /-separated paths relative to the setup.py directory,
            # *never* absolute paths
            hipified_sources.add(os.path.relpath(hipified_s_abs, build_dir))
        sources = list(hipified_sources)

    include_dirs += include_paths(cuda=True)
    kwargs["include_dirs"] = include_dirs
    kwargs["language"] = "c++"

    dlink_libraries = kwargs.get("dlink_libraries", [])
    dlink = kwargs.get("dlink", False) or dlink_libraries
    if dlink:
        user_compile_args = kwargs.get("extra_compile_args", {})
        dlink_flags = list(user_compile_args.get("nvcc_dlink", []))
        dlink_flags += ["-dlink"]
        dlink_flags += [f"-L{x}" for x in library_dirs]
        dlink_flags += [f"-l{x}" for x in dlink_libraries]
        # Build a fresh dict so the caller's extra_compile_args stays untouched.
        extra_compile_args = dict(user_compile_args)
        extra_compile_args["nvcc_dlink"] = dlink_flags
        kwargs["extra_compile_args"] = extra_compile_args
    return setuptools.Extension(name, sources, *args, **kwargs)
def include_paths(cuda: bool = False) -> List[str]:
    """
    Return the header search directories needed to build a C++ or CUDA extension.

    Args:
        cuda: If `True` (and HIP extensions are enabled), also return the
            ``THH`` directory and the ROCm ``include`` directory.

    Returns:
        A list of include path strings.
    """
    import torch

    include_root = os.path.join(os.path.dirname(torch.__file__), "include")
    result = [include_root]
    # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
    result.append(os.path.join(include_root, "torch", "csrc", "api", "include"))
    # Some internal (old) Torch headers don't properly prefix their includes,
    # so we need to pass -Itorch/lib/include/TH as well.
    result.extend(os.path.join(include_root, legacy) for legacy in ("TH", "THC"))
    if cuda and IS_HIP_EXTENSION:
        result.extend([os.path.join(include_root, "THH"), _join_rocm_home("include")])
    return result
def library_paths(cuda: bool = False) -> List[str]:
    """
    Return the library search directories needed to build a C++ or CUDA extension.

    Args:
        cuda: If `True` (and HIP extensions are enabled), also return the ROCm
            ``lib`` directory and, when ``HIP_HOME`` is set, its ``lib`` directory.

    Returns:
        A list of library path strings.
    """
    import torch

    # We need to link against libtorch.so
    result = [os.path.join(os.path.dirname(torch.__file__), "lib")]
    if not (cuda and IS_HIP_EXTENSION):
        return result
    result.append(_join_rocm_home("lib"))
    if HIP_HOME is not None:
        result.append(os.path.join(HIP_HOME, "lib"))
    return result
def load(
    name,
    sources: Union[str, List[str]],
    extra_cflags=None,
    extra_cuda_cflags=None,
    extra_ldflags=None,
    extra_include_paths=None,
    build_directory=None,
    verbose=False,
    with_cuda: Optional[bool] = None,
    is_python_module=True,
    is_standalone=False,
    keep_intermediates=True,
    torch_exclude=False,
):
    """
    Build and load a C++/HIP extension just-in-time (JIT).

    This is a thin front end: a single source path is wrapped into a list and
    every argument is then forwarded to ``_jit_compile``, whose result is
    returned unchanged.

    Args:
        name: The name of the extension to build. This MUST be the same as the
            name of the pybind11 module!
        sources: A single path, or a list of relative or absolute paths, to
            the source files.
        extra_cflags: optional list of compiler flags to forward to the build.
        extra_cuda_cflags: optional list of compiler flags to forward to the
            device compiler when building device sources.
        extra_ldflags: optional list of linker flags to forward to the build.
        extra_include_paths: optional list of include directories to forward
            to the build.
        build_directory: optional path to use as build workspace.
        verbose: If ``True``, turns on verbose logging of load steps.
        with_cuda: Determines whether device headers and libraries are added
            to the build. If ``None`` (default), ``_jit_compile`` decides by
            checking ``sources`` with ``_is_cuda_file``.
        is_python_module: If ``True`` (default), imports the produced shared
            library as a Python module. If ``False``, behavior depends on
            ``is_standalone``. Mutually exclusive with ``is_standalone``.
        is_standalone: If ``False`` (default) loads the constructed extension
            into the process as a plain dynamic library. If ``True``, build a
            standalone executable.
        keep_intermediates: forwarded unchanged to ``_jit_compile``.
        torch_exclude: forwarded unchanged to ``_jit_compile``.

    Returns:
        Whatever ``_jit_compile`` returns. NOTE(review): presumably the loaded
        module when ``is_python_module`` is true, nothing for a plain shared
        library, and the executable path when ``is_standalone`` is true --
        confirm against ``_jit_compile``.

    Raises:
        ValueError: if both ``is_python_module`` and ``is_standalone`` are true.

    Example:
        >>> # xdoctest: +SKIP
        >>> module = load(
        ...     name='extension',
        ...     sources=['extension.cpp', 'extension_kernel.cu'],
        ...     extra_cflags=['-O2'],
        ...     verbose=True)
    """
    source_list = [sources] if isinstance(sources, str) else sources
    return _jit_compile(
        name=name,
        sources=source_list,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_ldflags=extra_ldflags,
        extra_include_paths=extra_include_paths,
        build_directory=build_directory,
        verbose=verbose,
        with_cuda=with_cuda,
        is_python_module=is_python_module,
        is_standalone=is_standalone,
        keep_intermediates=keep_intermediates,
        torch_exclude=torch_exclude,
    )
def _get_pybind11_abi_build_flags():
    """Return ``-DPYBIND11_*`` defines mirroring the pybind11 ABI properties recorded on ``torch._C``."""
    # Note [Pybind11 ABI constants]
    #
    # Pybind11 before 2.4 used to build an ABI strings using the following pattern:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
    # Since 2.4 compiler type, stdlib and build abi parameters are also encoded like this:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
    #
    # This was done in order to further narrow down the chances of compiler ABI incompatibility
    # that can cause a hard to debug segfaults.
    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
    # captured during PyTorch native library compilation in torch/csrc/Module.cpp
    import torch

    recorded = (
        (pname, getattr(torch._C, f"_PYBIND11_{pname}", None))
        for pname in ("COMPILER_TYPE", "STDLIB", "BUILD_ABI")
    )
    # Properties torch does not expose (None) are simply left out.
    return [
        f'-DPYBIND11_{pname}=\\"{pval}\\"'
        for pname, pval in recorded
        if pval is not None
    ]
def _get_glibcxx_abi_build_flags():
    """Return a one-element list with the ``-D_GLIBCXX_USE_CXX11_ABI`` define matching the installed torch."""
    import torch

    abi_value = int(torch._C._GLIBCXX_USE_CXX11_ABI)
    return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(abi_value)]
def check_compiler_is_gcc(compiler):
    """Report whether ``compiler`` is the ``c++`` wrapper around gcc.

    The compiler is probed with ``-v`` and, if that fails, with ``--version``.

    Args:
        compiler: Executable name or path to probe.

    Returns:
        ``True`` only when the probe output has exactly one ``COLLECT_GCC=``
        line, that path resolves to a file named ``c++``, and the output
        contains ``gcc version``. ``False`` in every other case, including
        non-Linux hosts and a probe that fails both ways.
    """
    if not IS_LINUX:
        return False

    # Force the C locale so the banner is not localized.
    probe_env = dict(os.environ, LC_ALL="C")

    def _probe(flag):
        raw = subprocess.check_output(
            [compiler, flag], stderr=subprocess.STDOUT, env=probe_env
        )
        return raw.decode(*SUBPROCESS_DECODE_ARGS)

    try:
        banner = _probe("-v")
    except Exception:
        try:
            banner = _probe("--version")
        except Exception:
            return False

    # Check for 'gcc' or 'g++' for sccache wrapper
    collect_gcc = re.findall("^COLLECT_GCC=(.*)$", banner, re.MULTILINE)
    if len(collect_gcc) != 1:
        return False

    resolved = os.path.realpath(collect_gcc[0].strip())
    # On RHEL/CentOS c++ is a gcc compiler wrapper
    return os.path.basename(resolved) == "c++" and "gcc version" in banner
def _jit_compile(
    name,
    sources,
    extra_cflags,
    extra_cuda_cflags,
    extra_ldflags,
    extra_include_paths,
    build_directory: str,
    verbose: bool,
    with_cuda: Optional[bool],
    is_python_module,
    is_standalone,
    keep_intermediates=True,
    torch_exclude=False,
    hipify=True,
) -> Optional[object]:
    """Build the extension if its inputs changed, then load it.

    Steps, in order:

    1. Ask ``JIT_EXTENSION_VERSIONER`` whether the sources or build arguments
       changed; a version above zero is appended to ``name`` as ``_v<version>``.
    2. Take the ``lock`` file baton in ``build_directory``. The holder rebuilds
       when the version changed; every other caller waits for the release.
    3. On a rebuild, hipify the sources first when ``IS_HIP_EXTENSION``,
       ``with_cuda`` and ``hipify`` are all true, then emit the ninja file and
       run the build.

    Args:
        name: Extension name (without version suffix).
        sources: Source file paths.
        extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths:
            Optional lists forwarded to the build; ``None`` is treated as empty.
        build_directory: Directory holding the lock file and build outputs.
        verbose: Print progress messages to stderr.
        with_cuda: ``None`` means "decide from ``sources`` via ``_is_cuda_file``".
        is_python_module: Forwarded to the build and to the loader.
        is_standalone: Build an executable instead of a library.
        keep_intermediates: Forwarded to ``GeneratedFileCleaner``.
        torch_exclude: Build without referencing the PyTorch installation.
        hipify: Set to ``False`` to skip the hipify pass.

    Returns:
        The executable path when ``is_standalone`` is true, otherwise whatever
        ``_import_module_from_library`` returns.

    Raises:
        ValueError: If ``is_python_module`` and ``is_standalone`` are both true.
    """
    if is_python_module and is_standalone:
        raise ValueError(
            "`is_python_module` and `is_standalone` are mutually exclusive."
        )

    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))

    old_version = JIT_EXTENSION_VERSIONER.get_version(name)
    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
        name,
        sources,
        build_arguments=[
            extra_cflags,
            extra_cuda_cflags,
            extra_ldflags,
            extra_include_paths,
        ],
        build_directory=build_directory,
        with_cuda=with_cuda,
        is_python_module=is_python_module,
        is_standalone=is_standalone,
    )
    if version > 0:
        if version != old_version and verbose:
            print(
                f"The input conditions for extension module {name} have changed. "
                + f"Bumping to version {version} and re-building as {name}_v{version}...",
                file=sys.stderr,
            )
        name = f"{name}_v{version}"

    baton = FileBaton(os.path.join(build_directory, "lock"))
    if baton.try_acquire():
        try:
            if version != old_version:
                with GeneratedFileCleaner(
                    keep_intermediates=keep_intermediates
                ) as clean_ctx:
                    # Empty pattern when torch is excluded; otherwise a glob
                    # covering the whole PyTorch installation.
                    torch_path = ""
                    if not torch_exclude:
                        import torch

                        torch_path = os.path.join(
                            os.path.dirname(torch.__file__), "*"
                        )
                    if IS_HIP_EXTENSION and with_cuda and hipify:
                        hipify_result = hipify_python.hipify(
                            project_directory=build_directory,
                            output_directory=build_directory,
                            header_include_dirs=(
                                extra_include_paths
                                if extra_include_paths is not None
                                else []
                            ),
                            extra_files=[os.path.abspath(s) for s in sources],
                            ignores=[
                                _join_rocm_home("*"),
                                torch_path,
                            ],  # no need to hipify ROCm or PyTorch headers
                            show_detailed=verbose,
                            show_progress=verbose,
                            is_pytorch_extension=True,
                            hipify_extra_files_only=True,  # don't hipify everything in includes path
                            clean_ctx=clean_ctx,
                        )
                        # Build from the hipified copy of each source when one
                        # was produced, otherwise from the original file.
                        hipified_sources = set()
                        for source in sources:
                            s_abs = os.path.abspath(source)
                            hipified_sources.add(
                                hipify_result[s_abs].hipified_path
                                if s_abs in hipify_result
                                else s_abs
                            )
                        sources = list(hipified_sources)

                    _write_ninja_file_and_build_library(
                        name=name,
                        sources=sources,
                        extra_cflags=extra_cflags or [],
                        extra_cuda_cflags=extra_cuda_cflags or [],
                        extra_ldflags=extra_ldflags or [],
                        extra_include_paths=extra_include_paths or [],
                        build_directory=build_directory,
                        verbose=verbose,
                        with_cuda=with_cuda,
                        is_python_module=is_python_module,
                        is_standalone=is_standalone,
                        torch_exclude=torch_exclude,
                    )
            elif verbose:
                print(
                    "No modifications detected for re-loaded extension "
                    f"module {name}, skipping build step...",
                    file=sys.stderr,
                )
        finally:
            baton.release()
    else:
        baton.wait()

    if verbose:
        print(f"Loading extension module {name}...", file=sys.stderr)

    if is_standalone:
        return _get_exec_path(name, build_directory)

    return _import_module_from_library(
        name, build_directory, is_python_module, torch_exclude
    )
def _write_ninja_file_and_compile_objects(
    sources: List[str],
    objects,
    cflags,
    post_cflags,
    cuda_cflags,
    cuda_post_cflags,
    cuda_dlink_post_cflags,
    build_directory: str,
    verbose: bool,
    with_cuda: Optional[bool],
) -> None:
    """Emit ``build.ninja`` for a compile-only build and run ninja on it.

    No link step is generated: ``ldflags`` and ``library_target`` are passed
    to ``_write_ninja_file`` as ``None``.

    Args:
        sources: Source file paths.
        objects: Object file paths, one per source.
        cflags, post_cflags: Host compiler flags.
        cuda_cflags, cuda_post_cflags, cuda_dlink_post_cflags: Device
            compiler flags.
        build_directory: Directory that receives ``build.ninja`` and in which
            ninja is run.
        verbose: Print progress messages to stderr.
        with_cuda: ``None`` means "decide from ``sources`` via ``_is_cuda_file``".
    """
    verify_ninja_availability()
    get_compiler_abi_compatibility_and_version(get_cxx_compiler())

    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))

    ninja_file = os.path.join(build_directory, "build.ninja")
    if verbose:
        print(f"Emitting ninja build file {ninja_file}...", file=sys.stderr)
    _write_ninja_file(
        path=ninja_file,
        cflags=cflags,
        post_cflags=post_cflags,
        cuda_cflags=cuda_cflags,
        cuda_post_cflags=cuda_post_cflags,
        cuda_dlink_post_cflags=cuda_dlink_post_cflags,
        sources=sources,
        objects=objects,
        ldflags=None,
        library_target=None,
        with_cuda=with_cuda,
    )

    if verbose:
        print("Compiling objects...", file=sys.stderr)
    # It would be better if we could tell users the name of the extension
    # that failed to build but there isn't a good way to get it here.
    _run_ninja_build(
        build_directory,
        verbose,
        error_prefix="Error compiling objects for extension",
    )
def _write_ninja_file_and_build_library(
    name,
    sources: List[str],
    extra_cflags,
    extra_cuda_cflags,
    extra_ldflags,
    extra_include_paths,
    build_directory: str,
    verbose: bool,
    with_cuda: Optional[bool],
    is_python_module: bool,
    is_standalone: bool = False,
    torch_exclude: bool = False,
) -> None:
    """Emit ``build.ninja`` for a full compile-and-link build and run ninja.

    Args:
        name: Extension name; used for the output file and in messages.
        sources: Source file paths; duplicates are removed and the remainder
            sorted before the ninja file is written.
        extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths:
            Extra build inputs; falsy values are treated as empty lists.
        build_directory: Directory that receives ``build.ninja`` and in which
            ninja is run.
        verbose: Print progress messages to stderr.
        with_cuda: ``None`` means "decide from ``sources`` via ``_is_cuda_file``".
        is_python_module: Forwarded to the ninja-file writer.
        is_standalone: Build an executable instead of a library.
        torch_exclude: Build without referencing the PyTorch installation.
    """
    verify_ninja_availability()
    get_compiler_abi_compatibility_and_version(get_cxx_compiler(), torch_exclude)

    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))

    extra_ldflags = _prepare_ldflags(
        extra_ldflags or [], with_cuda, verbose, is_standalone, torch_exclude
    )

    ninja_file = os.path.join(build_directory, "build.ninja")
    if verbose:
        print(f"Emitting ninja build file {ninja_file}...", file=sys.stderr)
    # NOTE: Emitting a new ninja build file does not cause re-compilation if
    # the sources did not change, so it's ok to re-emit (and it's fast).
    unique_sources = sorted(set(sources))
    _write_ninja_file_to_build_library(
        path=ninja_file,
        name=name,
        sources=unique_sources,
        extra_cflags=extra_cflags or [],
        extra_cuda_cflags=extra_cuda_cflags or [],
        extra_ldflags=extra_ldflags or [],
        extra_include_paths=extra_include_paths or [],
        with_cuda=with_cuda,
        is_python_module=is_python_module,
        is_standalone=is_standalone,
        torch_exclude=torch_exclude,
    )

    if verbose:
        print(f"Building extension module {name}...", file=sys.stderr)
    _run_ninja_build(
        build_directory, verbose, error_prefix=f"Error building extension '{name}'"
    )
def is_ninja_available():
    """Return ``True`` if ``ninja --version`` runs successfully, else ``False``.

    Any exception raised while launching or running the command (for example
    a missing executable or a non-zero exit status) counts as "not available".
    """
    try:
        subprocess.check_output(["ninja", "--version"])
    except Exception:
        return False
    return True
def verify_ninja_availability():
    """Raise ``RuntimeError`` unless ``is_ninja_available()`` reports ``True``.

    Raises:
        RuntimeError: If the ninja build system cannot be run.
    """
    if is_ninja_available():
        return
    raise RuntimeError("Ninja is required to load C++ extensions")
def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone, torch_exclude):
    """Return the linker flags for an extension build.

    The result starts with the caller's flags, followed by the fixed
    section/GC flags, then the PyTorch libraries (unless ``torch_exclude``),
    then the HIP runtime (when ``with_cuda`` on a HIP build).

    The input list is copied, never modified. Appending in place made flags
    pile up when the same list object was passed to several builds.

    Args:
        extra_ldflags: Caller-supplied linker flags.
        with_cuda: Whether device code is part of the build.
        verbose: Print a note to stderr when HIP link flags are added.
        is_standalone: Building an executable; adds an rpath to the PyTorch
            ``lib`` directory and omits ``-ltorch_python``.
        torch_exclude: Do not link against PyTorch at all.

    Returns:
        A new list of linker flags.
    """
    ldflags = list(extra_ldflags)
    ldflags += [
        "-mcmodel=large",
        "-ffunction-sections",
        "-fdata-sections",
        "-Wl,--gc-sections",
        "-Wl,--cref",
    ]

    if not torch_exclude:
        import torch

        torch_lib_path = os.path.join(os.path.dirname(torch.__file__), "lib")
        ldflags.append(f"-L{torch_lib_path}")
        ldflags.append("-lc10")
        if with_cuda:
            ldflags.append("-lc10_hip" if IS_HIP_EXTENSION else "-lc10_cuda")
        ldflags.append("-ltorch_cpu")
        if with_cuda:
            ldflags.append("-ltorch_hip" if IS_HIP_EXTENSION else "-ltorch_cuda")
        ldflags.append("-ltorch")
        if not is_standalone:
            ldflags.append("-ltorch_python")
        if is_standalone:
            # The rpath points at the PyTorch lib directory, so it only makes
            # sense (and the path is only known) when PyTorch is linked.
            ldflags.append(f"-Wl,-rpath,{torch_lib_path}")

    if with_cuda and IS_HIP_EXTENSION:
        if verbose:
            print("Detected CUDA files, patching ldflags", file=sys.stderr)
        ldflags.append(f'-L{_join_rocm_home("lib")}')
        ldflags.append("-lamdhip64")

    return ldflags
def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
    """Return the ``--offload-arch`` flags for a ROCm build.

    Args:
        cflags: Flags already chosen for the build. If any of them mentions
            ``amdgpu-target`` or ``offload-arch`` the user is assumed to have
            picked the architectures and none are added.

    Returns:
        One ``--offload-arch=<arch>`` per architecture followed by
        ``-fno-gpu-rdc``, or just ``["-fno-gpu-rdc"]`` when ``cflags`` already
        selects an architecture. Architectures come from the
        ``PYTORCH_ROCM_ARCH`` environment variable (space- or
        semicolon-separated) when it is set and non-empty, otherwise from
        ``torch._C._cuda_getArchFlags()``.
    """
    # If cflags is given, there may already be user-provided arch flags in it
    # (from `extra_compile_args`)
    if cflags is not None and any(
        "amdgpu-target" in flag or "offload-arch" in flag for flag in cflags
    ):
        return ["-fno-gpu-rdc"]

    # Use same defaults as used for building PyTorch
    # Allow env var to override, just like during initial cmake build.
    env_archs = os.environ.get("PYTORCH_ROCM_ARCH", None)
    if env_archs:
        archs = env_archs.replace(" ", ";").split(";")
    else:
        import torch

        built_archs = torch._C._cuda_getArchFlags()
        archs = built_archs.split() if built_archs else []

    return [f"--offload-arch={arch}" for arch in archs] + ["-fno-gpu-rdc"]
def _get_num_workers(verbose: bool) -> Optional[int]:
    """Return the number of parallel jobs to hand to ninja.

    The count is ``MAX_JOBS`` when that environment variable is a digit
    string, capped at 80% of the CPU count; otherwise it is the 80% cap
    itself. When ``PREBUILD_THREAD_NUM`` holds a positive integer the count is
    divided by it (rounding down) and kept at 1 or more, so that splitting the
    budget between prebuild threads can never produce ``0``, which this module
    would pass to ninja as ``-j 0``.

    A ``PREBUILD_THREAD_NUM`` that is zero, negative or not an integer is
    ignored instead of raising.

    Args:
        verbose: Print the ``MAX_JOBS`` choice to stderr. The message for the
            default (no usable ``MAX_JOBS``) is printed regardless.

    Returns:
        The job count as an ``int``.
    """
    # os.cpu_count() may return None when the count cannot be determined.
    cpu_cap = int(max(1, (os.cpu_count() or 1) * 0.8))

    max_jobs = os.environ.get("MAX_JOBS")
    if max_jobs is not None and max_jobs.isdigit():
        workers = min(int(max_jobs), cpu_cap)
        if verbose:
            print(
                f"Using envvar MAX_JOBS ({workers}) as the number of workers...",
                file=sys.stderr,
            )
    else:
        workers = cpu_cap
        print(
            f"Using 0.8*cpu_cnt MAX_JOBS ({workers}) as the number of workers...",
            file=sys.stderr,
        )

    prebuild_thread_num = os.environ.get("PREBUILD_THREAD_NUM")
    if prebuild_thread_num is not None:
        try:
            threads = int(prebuild_thread_num)
        except ValueError:
            threads = 0
        if threads > 0:
            workers = max(1, workers // threads)

    return workers
def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None:
    """Run ``ninja -v`` in ``build_directory``.

    Args:
        build_directory: Working directory for the ninja process.
        verbose: Send ninja's output to file descriptor 1; otherwise it is
            captured and only surfaced in the error message on failure.
        error_prefix: Text that starts the error message when the build fails.

    Raises:
        RuntimeError: If ninja exits with a non-zero status. The captured
            output, when there is any, is appended to ``error_prefix``.
    """
    num_workers = _get_num_workers(verbose)
    command = ["ninja", "-v"]
    if num_workers is not None:
        command += ["-j", str(num_workers)]
    env = os.environ.copy()

    try:
        sys.stdout.flush()
        sys.stderr.flush()
        # Warning: don't pass stdout=None to subprocess.run to get output.
        # subprocess.run assumes that sys.__stdout__ has not been modified and
        # attempts to write to it by default. However, when we call _run_ninja_build
        # from ahead-of-time cpp extensions, the following happens:
        # 1) If the stdout encoding is not utf-8, setuptools detaches __stdout__.
        #    https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110
        #    (it probably shouldn't do this)
        # 2) subprocess.run (on POSIX, with no stdout override) relies on
        #    __stdout__ not being detached:
        #    https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214
        # To work around this, we pass in the fileno directly and hope that
        # it is valid.
        stdout_target = 1 if verbose else subprocess.PIPE
        subprocess.run(
            command,
            stdout=stdout_target,
            stderr=subprocess.STDOUT,
            cwd=build_directory,
            check=True,
            env=env,
        )
    except subprocess.CalledProcessError as e:
        # e.output contains the stdout and stderr of the build attempt.
        captured = getattr(e, "output", None)
        message = error_prefix
        if captured:
            message += f": {captured.decode(*SUBPROCESS_DECODE_ARGS)}"
        raise RuntimeError(message) from e
def _get_exec_path(module_name, path):
    """Return ``path`` joined with ``module_name`` plus the ``EXEC_EXT`` suffix."""
    exec_name = f"{module_name}{EXEC_EXT}"
    return os.path.join(path, exec_name)
def _import_module_from_library(module_name, path, is_python_module, torch_exclude):
    """Load a freshly built non-Python library into the process.

    Python modules are deliberately not imported here: the function returns
    ``None`` for them straight away. The importlib-based loading code that
    used to follow that ``return`` could never run and has been removed.

    Args:
        module_name: Library base name (the ``LIB_EXT`` suffix is appended).
        path: Directory that contains the built library.
        is_python_module: When true, nothing is loaded.
        torch_exclude: When true, nothing is loaded either, because loading
            goes through ``torch.ops.load_library``.

    Returns:
        Always ``None``.
    """
    filepath = os.path.join(path, f"{module_name}{LIB_EXT}")
    if is_python_module:
        return None
    if not torch_exclude:
        import torch

        torch.ops.load_library(filepath)
    return None
def _write_ninja_file_to_build_library(
    path,
    name,
    sources,
    extra_cflags,
    extra_cuda_cflags,
    extra_ldflags,
    extra_include_paths,
    with_cuda,
    is_python_module,
    is_standalone,
    torch_exclude,
) -> None:
    """Assemble compiler and linker flags, then write the ninja build file.

    Args:
        path: Destination of the ninja file.
        name: Extension name; becomes the output file stem and, unless
            ``torch_exclude``, the ``TORCH_EXTENSION_NAME`` define.
        sources: Source file paths.
        extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths:
            Extra build inputs; each entry is whitespace-stripped.
        with_cuda: Whether device code is part of the build.
        is_python_module: Add the pybind11 and Python include directories and
            the pybind11/glibcxx ABI defines taken from PyTorch.
        is_standalone: Link an executable (``EXEC_EXT``) rather than a shared
            library (``LIB_EXT`` with ``SHARED_FLAG``).
        torch_exclude: Leave out the PyTorch include paths and defines.
    """
    extra_cflags = [flag.strip() for flag in extra_cflags]
    extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
    extra_ldflags = [flag.strip() for flag in extra_ldflags]
    extra_include_paths = [flag.strip() for flag in extra_include_paths]

    # include_paths() gives us the location of torch/extension.h
    system_includes = [] if torch_exclude else include_paths(with_cuda)

    # FIXME: build python module excluded with torch, use `pybind11`
    # But we can't use this now because all aiter op based on torch
    # which means pybind11 related build flags must from torch now
    common_cflags = []
    if is_python_module:
        import pybind11

        extra_include_paths.append(pybind11.get_include())
        common_cflags += list(_get_pybind11_abi_build_flags())
        common_cflags += list(_get_glibcxx_abi_build_flags())

    # sysconfig.get_path('include') gives us the location of Python.h
    # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS
    # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder
    if is_python_module:
        python_include_path = sysconfig.get_path("include", scheme="posix_prefix")
        if python_include_path is not None:
            system_includes.append(python_include_path)

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    user_includes = [os.path.abspath(file) for file in extra_include_paths]

    if not torch_exclude:
        common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}")
        common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H")

    common_cflags += [f"-I{shlex.quote(include)}" for include in user_includes]
    common_cflags += [f"-isystem {shlex.quote(include)}" for include in system_includes]

    cflags = common_cflags + ["-fPIC", "-std=c++20"] + extra_cflags

    # Default to None so the call below never hits an unbound local when the
    # build is host-only or not a HIP build; _write_ninja_file treats None as
    # "no flags".
    cuda_flags = None
    if with_cuda and IS_HIP_EXTENSION:
        cuda_flags = ["-DWITH_HIP"] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
        cuda_flags += extra_cuda_cflags
        cuda_flags += _get_rocm_arch_flags(cuda_flags)

    def object_file_path(source_file: str) -> str:
        # '/path/to/file.cpp' -> 'file'
        file_name = os.path.splitext(os.path.basename(source_file))[0]
        if _is_cuda_file(source_file) and with_cuda:
            # Use a different object filename in case a C++ and CUDA file have
            # the same filename but different extension (.cpp vs. .cu).
            target = f"{file_name}.cuda.o"
        else:
            target = f"{file_name}.o"
        return target

    objects = [object_file_path(src) for src in sources]
    ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags
    ext = EXEC_EXT if is_standalone else LIB_EXT
    library_target = f"{name}{ext}"

    _write_ninja_file(
        path=path,
        cflags=cflags,
        post_cflags=None,
        cuda_cflags=cuda_flags,
        cuda_post_cflags=None,
        cuda_dlink_post_cflags=None,
        sources=sources,
        objects=objects,
        ldflags=ldflags,
        library_target=library_target,
        with_cuda=with_cuda,
    )
def _write_ninja_file(
    path,
    cflags,
    post_cflags,
    cuda_cflags,
    cuda_post_cflags,
    cuda_dlink_post_cflags,
    sources,
    objects,
    ldflags,
    library_target,
    with_cuda,
) -> None:
    r"""Write a ninja file that does the desired compiling and linking.

    `path`: Where to write this file
    `cflags`: list of flags to pass to $cxx. Can be None.
    `post_cflags`: list of flags to append to the $cxx invocation. Can be None.
    `cuda_cflags`: list of flags to pass to $nvcc. Can be None.
    `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
    `cuda_dlink_post_cflags`: list of flags for the device-link step. Can be
        None or empty; a non-empty list adds a `cuda_devlink` rule that
        produces `dlink.o`.
    `sources`: list of paths to source files
    `objects`: list of desired paths to objects, one per source. When a
        device-link step is emitted, `dlink.o` is appended to this object
        with `+=`, so a list passed by the caller is extended in place.
    `ldflags`: list of flags to pass to linker. Can be None.
    `library_target`: Name of the output library. Can be None; in that case,
        we do no linking.
    `with_cuda`: If we should be compiling with CUDA.
    """

    def sanitize_flags(flags):
        # None means "no flags"; every flag is whitespace-stripped.
        if flags is None:
            return []
        else:
            return [flag.strip() for flag in flags]

    cflags = sanitize_flags(cflags)
    post_cflags = sanitize_flags(post_cflags)
    cuda_cflags = sanitize_flags(cuda_cflags)
    cuda_post_cflags = sanitize_flags(cuda_post_cflags)
    cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
    ldflags = sanitize_flags(ldflags)
    # Sanity checks...
    assert len(sources) == len(objects)
    assert len(sources) > 0
    compiler = get_cxx_compiler()
    # Version 1.3 is required for the `deps` directive.
    config = ["ninja_required_version = 1.3"]
    config.append(f"cxx = {compiler}")
    if with_cuda or cuda_dlink_post_cflags:
        # The ninja variable keeps the name `nvcc` but points at hipcc.
        nvcc = executable_path("hipcc")
        config.append(f"nvcc = {nvcc}")
    if IS_HIP_EXTENSION:
        post_cflags = COMMON_HIP_FLAGS + post_cflags
    flags = [f'cflags = {" ".join(cflags)}']
    flags.append(f'post_cflags = {" ".join(post_cflags)}')
    if with_cuda:
        flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
        flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
    flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    sources = [os.path.abspath(file) for file in sources]
    # See https://ninja-build.org/build.ninja.html for reference.
    compile_rule = ["rule compile"]
    compile_rule.append(
        " command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags"
    )
    compile_rule.append(" depfile = $out.d")
    compile_rule.append(" deps = gcc")
    if with_cuda:
        cuda_compile_rule = ["rule cuda_compile"]
        # No dependency-file flags are generated for the device compiler.
        nvcc_gendeps = ""
        cuda_compile_rule.append(
            f" command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags"
        )
    # Emit one build rule per source to enable incremental build.
    build = []
    for source_file, object_file in zip(sources, objects):
        is_cuda_source = _is_cuda_file(source_file) and with_cuda
        rule = "cuda_compile" if is_cuda_source else "compile"
        # Ninja needs spaces inside paths escaped as "$ ".
        source_file = source_file.replace(" ", "$ ")
        object_file = object_file.replace(" ", "$ ")
        build.append(f"build {object_file}: {rule} {source_file}")
    flags.append(f'ldflags = {" ".join(ldflags)}')
    if cuda_dlink_post_cflags:
        # The device-link output goes next to the first object file.
        devlink_out = os.path.join(os.path.dirname(objects[0]), "dlink.o")
        devlink_rule = ["rule cuda_devlink"]
        devlink_rule.append(" command = $nvcc $in -o $out $cuda_dlink_post_cflags")
        devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
        objects += [devlink_out]
    else:
        devlink_rule, devlink = [], []
    if library_target is not None:
        link_rule = ["rule link"]
        # Object paths are passed to the linker through a response file.
        link_rule.append(
            " command = $cxx @$out.rsp $ldflags -o $out\n rspfile = $out.rsp\n rspfile_content = $in"
        )
        link = [f'build {library_target}: link {" ".join(objects)}']
        default = [f"default {library_target}"]
    else:
        link_rule, link, default = [], [], []
    # 'Blocks' should be separated by newlines, for visual benefit.
    blocks = [config, flags, compile_rule]
    if with_cuda:
        blocks.append(cuda_compile_rule)  # type: ignore[possibly-undefined]
    blocks += [devlink_rule, link_rule, build, devlink, link, default]
    content = "\n\n".join("\n".join(b) for b in blocks)
    # Ninja requires a new lines at the end of the .ninja file
    content += "\n"
    _maybe_write(path, content)
# def _join_cuda_home(*paths) -> str:
# """
# Join paths with CUDA_HOME, or raises an error if it CUDA_HOME is not set.
# This is basically a lazy way of raising an error for missing $CUDA_HOME
# only once we need to get any CUDA-specific path.
# """
# if CUDA_HOME is None:
# raise OSError('CUDA_HOME environment variable is not set. '
# 'Please set it to your CUDA install root.')
# return os.path.join(CUDA_HOME, *paths)
def _is_cuda_file(path: str) -> bool:
    """Report whether ``path`` is to be treated as a device source file.

    Always returns ``True``, whatever the file extension. The extension
    check (``.cu``, ``.cuh``, plus ``.hip`` on HIP builds) that used to follow
    the unconditional ``return`` could never run and has been removed.

    Args:
        path: Source file path; not inspected.

    Returns:
        ``True``.
    """
    return True
# SPDX-License-Identifier: MIT
# mypy: allow-untyped-defs
import os
import time
class FileBaton:
    """File-based lock: whoever creates the lock file holds the baton."""

    def __init__(self, lock_file_path, wait_seconds=0.1):
        """
        Set up a baton around ``lock_file_path`` without touching the file.

        Args:
            lock_file_path: Path of the file whose existence marks the lock
                as taken.
            wait_seconds: Pause between existence checks in ``wait()``.
        """
        self.lock_file_path = lock_file_path
        self.wait_seconds = wait_seconds
        # Descriptor of the lock file; None until try_acquire() succeeds.
        self.fd = None

    def try_acquire(self):
        """
        Attempt to create the lock file exclusively.

        Returns:
            True if this call created the file, False if it already existed.
        """
        try:
            descriptor = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
        except FileExistsError:
            return False
        self.fd = descriptor
        return True

    def wait(self):
        """
        Block until the lock file no longer exists.

        The file is polled, sleeping ``wait_seconds`` between checks.
        """
        while True:
            if not os.path.exists(self.lock_file_path):
                return
            time.sleep(self.wait_seconds)

    def release(self):
        """Close the lock file descriptor, if any, and delete the lock file."""
        descriptor = self.fd
        if descriptor is not None:
            os.close(descriptor)
        os.remove(self.lock_file_path)
# SPDX-License-Identifier: MIT
\ No newline at end of file
# SPDX-License-Identifier: MIT
"""Constants for annotations in the mapping.
The constants defined here are used to annotate the mapping tuples in cuda_to_hip_mappings.py.
They are based on
https://github.com/ROCm/HIPIFY/blob/master/src/Statistics.h
and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported
mapping.
"""
# --- Type of mapping (CONV_*) ---
# NOTE(review): CONV_VERSION is a one-element tuple, unlike the plain integer
# constants that follow — confirm this is intended.
CONV_VERSION = (0,)
CONV_INIT = 1
CONV_DEVICE = 2
CONV_MEM = 3
CONV_KERN = 4
CONV_COORD_FUNC = 5
CONV_MATH_FUNC = 6
CONV_DEVICE_FUNC = 7
CONV_SPECIAL_FUNC = 8
CONV_STREAM = 9
CONV_EVENT = 10
CONV_OCCUPANCY = 11
CONV_CONTEXT = 12
CONV_PEER = 13
CONV_MODULE = 14
CONV_CACHE = 15
CONV_EXEC = 16
CONV_ERROR = 17
CONV_DEF = 18
CONV_TEX = 19
CONV_GL = 20
CONV_GRAPHICS = 21
CONV_SURFACE = 22
CONV_JIT = 23
CONV_D3D9 = 24
CONV_D3D10 = 25
CONV_D3D11 = 26
CONV_VDPAU = 27
CONV_EGL = 28
CONV_THREAD = 29
CONV_OTHER = 30
CONV_INCLUDE = 31
CONV_INCLUDE_CUDA_MAIN_H = 32
CONV_TYPE = 33
CONV_LITERAL = 34
CONV_NUMERIC_LITERAL = 35
CONV_LAST = 36
# --- API of mapping (API_*) ---
# NOTE: API_LAST (42) is not the largest API_* value; API_FFT, API_RTC and
# API_ROCTX are numbered after it.
API_DRIVER = 37
API_RUNTIME = 38
API_BLAS = 39
API_SPECIAL = 40
API_RAND = 41
API_LAST = 42
API_FFT = 43
API_RTC = 44
API_ROCTX = 45
# --- Unsupported mapping ---
HIP_UNSUPPORTED = 46
# --- Further API_* tags, numbered from 1337 and so outside the 0-46 range above ---
API_PYTORCH = 1337
API_CAFFE2 = 1338
API_C10 = 1339
API_ROCMSMI = 1340
This source diff could not be displayed because it is too large. You can view the blob instead.
# SPDX-License-Identifier: MIT
#!/usr/bin/env python3
# mypy: allow-untyped-defs
"""The Python Hipify script.
##
# Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""
import argparse
import fnmatch
import re
import shutil
import sys
import os
from . import constants
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable
from enum import Enum
class CurrentState(Enum):
    """Processing state recorded on a ``HipifyResult``."""

    # Set when the result record is first created for a file.
    INITIALIZED = 1
    DONE = 2
class HipifyResult:
    """Outcome record for one file handled by hipify.

    Attributes are plain and public: ``current_state``, ``hipified_path`` and
    ``status`` (a free-form string, empty on construction).
    """

    def __init__(self, current_state, hipified_path):
        self.status = ""
        self.hipified_path = hipified_path
        self.current_state = current_state

    def __str__(self):
        return f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}"
# Type alias: maps a file path string to the HipifyResult recorded for it.
HipifyFinalResult = Dict[str, HipifyResult]
# Marker comment line announcing a hipify-generated file.
# NOTE(review): where it is written is not visible in this part of the file.
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
# Module-level result table; preprocess_file_and_save_result() stores entries
# here keyed by absolute path.
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
# Hardcode the PyTorch template map
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
# Public names exported by `from <module> import *`.
# "is_caffe2_gpu_file" used to be listed twice; each name now appears once.
__all__ = [
    "InputError",
    "openf",
    "bcolors",
    "GeneratedFileCleaner",
    "match_extensions",
    "matched_files_iter",
    "preprocess_file_and_save_result",
    "compute_stats",
    "add_dim3",
    "processKernelLaunches",
    "find_closure_group",
    "find_bracket_group",
    "find_parentheses_group",
    "replace_math_functions",
    "hip_header_magic",
    "replace_extern_shared",
    "get_hip_file_path",
    "is_out_of_place",
    "is_pytorch_file",
    "is_cusparse_file",
    "is_special_file",
    "is_caffe2_gpu_file",
    "Trie",
    "preprocessor",
    "file_specific_replacement",
    "file_add_header",
    "fix_static_global_kernels",
    "extract_arguments",
    "str2bool",
    "CurrentState",
    "HipifyResult",
    "hipify",
]
class InputError(Exception):
    """Raised for errors in the input.

    The text is kept on ``message`` and rendered with an ``Input error:``
    prefix by ``str()``.
    """

    def __init__(self, message):
        self.message = message
        super().__init__(message)

    def __str__(self):
        text = self.message
        return f"Input error: {text}"
def openf(filename, mode):
    """Open ``filename`` in ``mode`` with ``errors="ignore"``.

    Returns:
        The file object produced by the built-in ``open``.
    """
    handle = open(filename, mode, errors="ignore")
    return handle
class bcolors:
    """Terminal escape sequences used to colour printed output.

    ``ENDC`` is the reset sequence (``ESC[0m``).
    """

    # Colours
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    # Reset
    ENDC = "\033[0m"
    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
# The files hipify writes are usually just intermediates. Running hipify and
# the compilation inside a `with GeneratedFileCleaner(keep_intermediates=False)`
# block removes them again on exit.
# The main use case is cpp_extensions, specifically the load method.
# Keeping intermediates is normally a good idea (to inspect them after an
# error, or to avoid recompiling unchanged files); where that is unwanted
# (e.g. in CI) this class cleans them up.
class GeneratedFileCleaner:
    """Context Manager to clean up generated files"""

    def __init__(self, keep_intermediates=False):
        self.keep_intermediates = keep_intermediates
        # Absolute paths of files created through open().
        self.files_to_clean = set()
        # Absolute paths of directories created through makedirs(), in
        # creation order.
        self.dirs_to_clean = []

    def __enter__(self):
        return self

    def open(self, fn, *args, **kwargs):
        """Open ``fn`` like the built-in ``open``; remember it if it is new."""
        already_there = os.path.exists(fn)
        if not already_there:
            self.files_to_clean.add(os.path.abspath(fn))
        return open(fn, *args, **kwargs)

    def makedirs(self, dn, exist_ok=False):
        """Create ``dn`` and any missing parents, remembering each new directory."""
        parent, leaf = os.path.split(dn)
        if not leaf:
            # Nothing after the last separator: split once more.
            parent, leaf = os.path.split(parent)
        if parent and leaf and not os.path.exists(parent):
            self.makedirs(parent, exist_ok=True)
        if exist_ok and os.path.isdir(dn):
            return
        os.mkdir(dn)
        self.dirs_to_clean.append(os.path.abspath(dn))

    def __exit__(self, type, value, traceback):
        if self.keep_intermediates:
            return
        for created_file in self.files_to_clean:
            os.unlink(created_file)
        # Remove directories newest-first.
        for created_dir in reversed(self.dirs_to_clean):
            os.rmdir(created_dir)
def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Return ``True`` if ``filename`` ends with at least one of ``extensions``."""
    for extension in extensions:
        if filename.endswith(extension):
            return True
    return False
def _fnmatch(filepath, patterns):
    """Return ``True`` if ``filepath`` matches any glob in ``patterns``."""
    for pattern in patterns:
        if fnmatch.fnmatch(filepath, pattern):
            return True
    return False
def matched_files_iter(
    root_path: str,
    includes: Iterable = (),
    ignores: Iterable = (),
    extensions: Iterable = (),
    out_of_place_only: bool = False,
    is_pytorch_extension: bool = False,
) -> Iterator[str]:
    """Walk ``root_path`` and yield the paths of files selected for hipify.

    A file is yielded when its path matches a glob in ``includes``, matches
    none in ``ignores``, and either ends with one of ``extensions`` or is
    listed verbatim in ``includes``. Two further filters apply:

    * unless ``is_pytorch_extension``, the relative path must satisfy
      ``is_pytorch_file`` or ``is_caffe2_gpu_file``;
    * with ``out_of_place_only``, it must satisfy ``is_out_of_place``.

    At the top level of ``root_path`` the ``.git``, ``build`` and
    ``third_party`` directories are pruned from the walk (``third_party`` is
    replaced by ``third_party/nvfuser``).
    """
    exact_matches = set(includes)

    # This is a very rough heuristic; really, we want to avoid scanning
    # any file which is not checked into source control, but this script
    # needs to work even if you're in a Git or Hg checkout, so easier to
    # just block the biggest time sinks that won't matter in the
    # end.
    for abs_dirpath, dirs, filenames in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == ".":
            for pruned in (".git", "build"):
                if pruned in dirs:
                    dirs.remove(pruned)
            if "third_party" in dirs:
                dirs.remove("third_party")
                dirs.append("third_party/nvfuser")

        for filename in filenames:
            filepath = os.path.join(abs_dirpath, filename)
            rel_filepath = os.path.join(rel_dirpath, filename)

            if not _fnmatch(filepath, includes):
                continue
            if _fnmatch(filepath, ignores):
                continue
            # We respect extensions, UNLESS you wrote the entire
            # filename verbatim, in which case we always accept it
            if not (
                match_extensions(filepath, extensions) or filepath in exact_matches
            ):
                continue

            # for pytorch extensions, consider all files
            if (
                not is_pytorch_extension
                and not is_pytorch_file(rel_filepath)
                and not is_caffe2_gpu_file(rel_filepath)
            ):
                continue
            if out_of_place_only and not is_out_of_place(rel_filepath):
                continue
            yield filepath
def preprocess_file_and_save_result(
    output_directory: str,
    filepath: str,
    all_files: Iterable,
    header_include_dirs: Iterable,
    stats: Dict[str, List],
    hip_clang_launch: bool,
    is_pytorch_extension: bool,
    clean_ctx: GeneratedFileCleaner,
    show_progress: bool,
) -> None:
    """Run ``preprocessor`` on one file and store its result in ``HIPIFY_FINAL_RESULT``.

    The table entry, keyed by the absolute path of ``filepath`` under
    ``output_directory``, is first set to an ``INITIALIZED`` placeholder and
    then replaced by the preprocessor's result. With ``show_progress`` a line
    is printed for every result whose status does not contain ``"ignored"``.
    """
    fin_path = os.path.abspath(os.path.join(output_directory, filepath))

    # Placeholder entry that is in the table while the preprocessor runs.
    HIPIFY_FINAL_RESULT[fin_path] = HipifyResult(
        current_state=CurrentState.INITIALIZED, hipified_path=fin_path
    )

    result = preprocessor(
        output_directory,
        filepath,
        all_files,
        header_include_dirs,
        stats,
        hip_clang_launch,
        is_pytorch_extension,
        clean_ctx,
        show_progress,
    )

    # Show what happened
    was_ignored = "ignored" in result.status if show_progress else True
    if not was_ignored:
        print(fin_path, "->", result.hipified_path, result.status, flush=True)

    HIPIFY_FINAL_RESULT[fin_path] = result
def compute_stats(stats):
    """Print a summary of unsupported CUDA calls and replaced kernel launches.

    Args:
        stats: Mapping with an ``"unsupported_calls"`` entry (an iterable of
            ``(cuda_call, filepath)`` pairs) and a ``"kernel_launches"`` entry
            (a sized collection).
    """
    # Distinct call names; the file each one came from is not reported.
    unsupported_calls = set()
    for cuda_call, _filepath in stats["unsupported_calls"]:
        unsupported_calls.add(cuda_call)

    # Print the number of unsupported calls
    print(
        f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}"
    )
    # Print the list of unsupported calls
    print(", ".join(unsupported_calls))

    # Print the number of kernel launches
    launch_count = len(stats["kernel_launches"])
    print(f"\nTotal number of replaced kernel launches: {launch_count:d}")
def add_dim3(kernel_string, cuda_kernel):
    """Wrap the first two launch arguments of a kernel launch in ``dim3()``.

    Args:
        kernel_string: The launch configuration text, e.g.
            ``"<<<grid, block, 0, stream>>>"``.
        cuda_kernel: Text containing that launch configuration.

    Returns:
        ``cuda_kernel`` with the first two comma-separated launch arguments
        (commas inside parentheses do not separate) each wrapped in
        ``dim3(...)``.
    """
    launch_args = kernel_string.replace("<<<", "").replace(">>>", "")
    last_index = len(launch_args) - 1

    # One {"start": ..., "end": ...} record per argument of interest.
    arg_locs: List[Dict[str, int]] = [{}, {}]
    arg_locs[0]["start"] = 0
    found = 0
    depth = 0
    for ind, c in enumerate(launch_args):
        if found > 1:
            break
        if c == "(":
            depth += 1
        elif c == ")":
            depth -= 1
        at_boundary = c == "," or ind == last_index
        if at_boundary and depth == 0:
            # A comma ends the argument just before it; the final character
            # of the string belongs to the argument.
            arg_locs[found]["end"] = ind + (0 if c == "," else 1)
            found += 1
            if found < 2:
                arg_locs[found]["start"] = ind + 1

    first, second = arg_locs
    first_start, first_end = first["start"], first["end"]
    second_start, second_end = second["start"], second["end"]

    # The "raw" first argument keeps its trailing separator so that the two
    # raw pieces are adjacent in the launch text.
    first_raw = launch_args[first_start : first_end + 1]
    second_raw = launch_args[second_start:second_end]
    first_clean = launch_args[first_start:first_end].replace("\n", "").strip(" ")
    second_clean = launch_args[second_start:second_end].replace("\n", "").strip(" ")

    first_wrapped = first_raw.replace(first_clean, f"dim3({first_clean})")
    second_wrapped = second_raw.replace(second_clean, f"dim3({second_clean})")

    return cuda_kernel.replace(
        first_raw + second_raw, first_wrapped + second_wrapped
    )
# Matches a namespace qualifier split across a line continuation: leading
# spaces (group 1), "detai" or "detail" (group 2), "::", spaces, a backslash,
# a newline, then more spaces. processKernelLaunches() replaces each match
# with group 1 + group 2 + "::", which drops the line break.
RE_KERNEL_LAUNCH = re.compile(r"([ ]+)(detail?)::[ ]+\\\n[ ]+")
def processKernelLaunches(string, stats):
"""Replace the CUDA style Kernel launches with the HIP style kernel launches."""
# Concat the namespace with the kernel names. (Find cleaner way of doing this later).
string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string)
def grab_method_and_template(in_kernel):
# The positions for relevant kernel components.
pos = {
"kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
"kernel_name": {"start": -1, "end": -1},
"template": {"start": -1, "end": -1},
}
# Count for balancing template
count = {"<>": 0}
# Status for whether we are parsing a certain item.
START = 0
AT_TEMPLATE = 1
AFTER_TEMPLATE = 2
AT_KERNEL_NAME = 3
status = START
# Parse the string character by character
for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
char = string[i]
# Handle Templating Arguments
if status in (START, AT_TEMPLATE):
if char == ">":
if status == START:
status = AT_TEMPLATE
pos["template"]["end"] = i
count["<>"] += 1
if char == "<":
count["<>"] -= 1
if count["<>"] == 0 and (status == AT_TEMPLATE):
pos["template"]["start"] = i
status = AFTER_TEMPLATE
# Handle Kernel Name
if status != AT_TEMPLATE:
if string[i].isalnum() or string[i] in {"(", ")", "_", ":", "#"}:
if status != AT_KERNEL_NAME:
status = AT_KERNEL_NAME
pos["kernel_name"]["end"] = i
# Case: Kernel name starts the string.
if i == 0:
pos["kernel_name"]["start"] = 0
# Finished
return [
(pos["kernel_name"]),
(pos["template"]),
(pos["kernel_launch"]),
]
else:
# Potential ending point if we're already traversing a kernel's name.
if status == AT_KERNEL_NAME:
pos["kernel_name"]["start"] = i
# Finished
return [
(pos["kernel_name"]),
(pos["template"]),
(pos["kernel_launch"]),
]
def find_kernel_bounds(string):
"""Finds the starting and ending points for all kernel launches in the string."""
kernel_end = 0
kernel_positions = []
# Continue until we cannot find any more kernels anymore.
while string.find("<<<", kernel_end) != -1:
# Get kernel starting position (starting from the previous ending point)
kernel_start = string.find("<<<", kernel_end)
# Get kernel ending position (adjust end point past the >>>)
kernel_end = string.find(">>>", kernel_start) + 3
if kernel_end <= 0:
raise InputError("no kernel end found")
# Add to list of traversed kernels
kernel_positions.append(
{
"start": kernel_start,
"end": kernel_end,
"group": string[kernel_start:kernel_end],
}
)
return kernel_positions
# Replace comments and string literals from the code so that find_kernel_bounds does not
# wrongly capture kernels in comments and string literals.
# This function replaces them with "x" to keep positions.
def mask_comments(string):
in_comment = ""
prev_c = ""
new_string = ""
for c in string:
if in_comment == "":
# Outside comments
if c == "/" and prev_c == "/":
in_comment = "//"
elif c == "*" and prev_c == "/":
in_comment = "/*"
elif c == '"' and prev_c != "\\" and prev_c != "'":
in_comment = '"'
elif in_comment == "//":
# In // xxx
if c == "\r" or c == "\n":
in_comment = ""
elif in_comment == "/*":
# In /* xxx */
if c == "/" and prev_c == "*":
in_comment = ""
elif in_comment == '"':
# In ""
if c == '"' and prev_c != "\\":
in_comment = ""
prev_c = c
if in_comment == "":
new_string += c
else:
new_string += "x"
return new_string
# Grab positional ranges of all kernel launches
get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
output_string = string
# Replace each CUDA kernel with a HIP kernel.
for kernel in get_kernel_positions:
# Get kernel components
params = grab_method_and_template(kernel)
# Find parenthesis after kernel launch
parenthesis = string.find("(", kernel["end"])
# Extract cuda kernel
cuda_kernel = string[params[0]["start"] : parenthesis + 1]
kernel_string = string[kernel["start"] : kernel["end"]]
end_param_index = 0 if params[1]["end"] == -1 else 1
kernel_name_with_template = string[
params[0]["start"] : params[end_param_index]["end"] + 1
]
cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
# Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
num_klp = len(
extract_arguments(
0, kernel["group"].replace("<<<", "(").replace(">>>", ")")
)
)
hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
">>>", ", 0" * (4 - num_klp) + ">>>"
).replace("<<<", ", ").replace(">>>", ", ").replace(
kernel_name_with_template, "(" + kernel_name_with_template + ")"
)
# Replace cuda kernel with hip kernel
output_string = output_string.replace(cuda_kernel, hip_kernel)
# Update the statistics
stats["kernel_launches"].append(hip_kernel)
return output_string
def find_closure_group(input_string, start, group):
    """Locate the first balanced ``group[0]`` ... ``group[1]`` pair.

    Scanning begins at index ``start``. Nested openers raise the depth and
    the matching closer of the outermost opener ends the search.

    Args:
        input_string: Text to scan.
        start: Index at which scanning begins.
        group: Two-item sequence ``[opener, closer]``, e.g. ``["(", ")"]``
            or ``["{", "}"]``.

    Returns:
        ``(open_index, close_index)`` of the first balanced group, or
        ``(None, None)`` when no balanced group is found.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)
    """
    depth = 0
    outer_open = -1
    for index in range(start, len(input_string)):
        current = input_string[index]
        if current == group[0]:
            if depth == 0:
                # First opener seen: this is where the reported group starts.
                outer_open = index
            depth += 1
        elif current == group[1] and depth > 0:
            depth -= 1
            if depth == 0:
                return outer_open, index
    return None, None
def find_bracket_group(input_string, start):
    """Find the first balanced ``{ ... }`` group at or after ``start``.

    Returns the same ``(open_index, close_index)`` tuple as
    ``find_closure_group``, or ``(None, None)`` when nothing balanced is found.
    """
    curly_pair = ["{", "}"]
    return find_closure_group(input_string, start, group=curly_pair)
def find_parentheses_group(input_string, start):
    """Find the first balanced ``( ... )`` group at or after ``start``.

    Returns the same ``(open_index, close_index)`` tuple as
    ``find_closure_group``, or ``(None, None)`` when nothing balanced is found.
    """
    round_pair = ["(", ")"]
    return find_closure_group(input_string, start, group=round_pair)
# Matches the word "assert" followed by optional spaces and an opening parenthesis.
RE_ASSERT = re.compile(r"\bassert[ ]*\(")
def replace_math_functions(input_string):
    """Rewrite math calls according to ``MATH_TRANSPILATIONS``.

    FIXME: Temporarily replace std:: invocations of math functions with
    non-std:: versions to prevent linker errors.

    NOTE: This can lead to correctness issues when running tests, since the
    correct version of the math function (exp/expf) might not get called.
    Plan is to remove this function once HIP supports std:: math function
    calls inside device code.

    Each key of ``MATH_TRANSPILATIONS`` followed by ``(`` is replaced with
    its mapped value followed by ``(``, one entry after another.
    """
    rewritten = input_string
    for source_name, target_name in MATH_TRANSPILATIONS.items():
        rewritten = rewritten.replace(f"{source_name}(", f"{target_name}(")
    return rewritten
# Matches a call to __syncthreads, optionally preceded by "::".
RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")


def hip_header_magic(input_string):
    """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.

    Args:
        input_string: Source text to inspect.

    Returns:
        ``input_string`` unchanged when it already includes
        ``hip/hip_runtime.h`` or ``hip/hip_runtime_api.h`` (quoted or angle
        form) or shows no sign of device code; otherwise the text prefixed
        with ``#include "hip/hip_runtime.h"``.

    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.
    """
    # Check if one of the following headers is already included.
    headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
    for header in headers:
        # Escape the name so "." only matches a literal dot; otherwise a
        # look-alike include would be mistaken for the real header.
        escaped = re.escape(header)
        if re.search(rf'#include ("{escaped}"|<{escaped}>)', input_string):
            return input_string

    # Rough logic to detect if we're inside device code
    has_device_logic = (
        "hipLaunchKernelGGL" in input_string
        or "__global__" in input_string
        or "__shared__" in input_string
        or RE_SYNCTHREADS.search(input_string) is not None
    )

    # If device logic found, provide the necessary header.
    if has_device_logic:
        return '#include "hip/hip_runtime.h"\n' + input_string

    return input_string
# Matches "extern [qualifier] __shared__ <type> <name>[];" declarations.
RE_EXTERN_SHARED = re.compile(
    r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;"
)


def replace_extern_shared(input_string):
    """Rewrite ``extern __shared__ type foo[];`` into the HIP_DYNAMIC_SHARED() macro.

    The optional qualifier before ``__shared__`` is carried over ahead of the
    type; the trailing semicolon is consumed by the match.

    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, smem)"
    """

    def _as_macro(match):
        qualifier = match.group(1) or ""
        element_type = match.group(2)
        variable = match.group(3)
        return f"HIP_DYNAMIC_SHARED({qualifier} {element_type}, {variable})"

    return RE_EXTERN_SHARED.sub(_as_macro, input_string)
def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
    """Return the path the hipified version of ``rel_filepath`` is written to.

    ``rel_filepath`` must be relative. Files that are hipified in place
    (not an extension and not out-of-place) keep their path unchanged.

    Otherwise the name is disambiguated so the original is not overwritten:

    - "cuda"/"CUDA"/"THC" in the directory part become "hip"/"HIP"/"THH";
    - "cuda"/"CUDA" in the file stem become "hip"/"HIP", and "THC" becomes
      "THH" unless the directory is "caffe2/core" (THCCachingAllocator
      special case);
    - a ".cu" extension always becomes ".hip", since those files hold CUDA
      kernels that must go through the hip compiler;
    - for non-extension files whose directory did not change, a "hip"
      directory is inserted as the file's direct parent;
    - for extension files where neither the directory nor the file name
      (including extension) changed, "_hip" is appended to the stem.

    This isn't set in stone; it may be adjusted to support other naming
    conventions.
    """
    assert not os.path.isabs(rel_filepath)
    # Some PyTorch source files are HIPified in place; is_out_of_place tells
    # us whether that is the case.
    if not is_pytorch_extension and not is_out_of_place(rel_filepath):
        return rel_filepath

    source_dir, source_name = os.path.split(rel_filepath)
    stem, suffix = os.path.splitext(source_name)

    hip_suffix = ".hip" if suffix == ".cu" else suffix

    hip_dir = source_dir
    for cuda_token, hip_token in (("cuda", "hip"), ("CUDA", "HIP"), ("THC", "THH")):
        hip_dir = hip_dir.replace(cuda_token, hip_token)

    hip_stem = stem.replace("cuda", "hip").replace("CUDA", "HIP")
    # Special case to handle caffe2/core/THCCachingAllocator
    if hip_dir != "caffe2/core":
        hip_stem = hip_stem.replace("THC", "THH")

    dir_unchanged = hip_dir == source_dir
    if is_pytorch_extension:
        if dir_unchanged and (hip_stem + hip_suffix) == source_name:
            hip_stem = hip_stem + "_hip"
    elif dir_unchanged:
        hip_dir = os.path.join(hip_dir, "hip")

    return os.path.join(hip_dir, hip_stem + hip_suffix)
def is_out_of_place(rel_filepath):
    """Return False for paths under the in-place prefixes, True otherwise.

    ``rel_filepath`` must be a relative path.
    """
    assert not os.path.isabs(rel_filepath)
    in_place_prefixes = (
        "torch/",
        "third_party/nvfuser/",
        "tools/autograd/templates/",
    )
    return not rel_filepath.startswith(in_place_prefixes)
# Keep this synchronized with includes/ignores in build_hygon.py
def is_pytorch_file(rel_filepath):
    """Return True when the relative path lies in one of the PyTorch source trees.

    Everything under ``aten/`` counts except ``aten/src/ATen/core/``; paths
    under ``torch/``, ``third_party/nvfuser/`` and
    ``tools/autograd/templates/`` count as well.
    """
    assert not os.path.isabs(rel_filepath)
    # The excluded subtree also starts with "aten/", so rule it out first.
    if rel_filepath.startswith("aten/src/ATen/core/"):
        return False
    pytorch_prefixes = (
        "aten/",
        "torch/",
        "third_party/nvfuser/",
        "tools/autograd/templates/",
    )
    return rel_filepath.startswith(pytorch_prefixes)
def is_cusparse_file(rel_filepath):
    """Return True for PyTorch files whose path contains "sparse" (case-insensitive)."""
    return is_pytorch_file(rel_filepath) and "sparse" in rel_filepath.lower()
def is_special_file(rel_filepath):
    """Return True for PyTorch files that should use the "special" mappings.

    That is any PyTorch file whose lower-cased path contains "sparse", or
    contains "linalg" without also containing "batchlinearalgebralibblas".
    """
    if not is_pytorch_file(rel_filepath):
        return False
    lowered = rel_filepath.lower()
    if "sparse" in lowered:
        return True
    if "linalg" in lowered:
        # don't use "special" mappings for this specific linalg cublas file
        return "batchlinearalgebralibblas" not in lowered
    return False
def is_caffe2_gpu_file(rel_filepath):
    """Return True when the relative path looks like a Caffe2 GPU source file.

    Anything starting with ``c10/cuda`` qualifies. Otherwise the file name
    must contain "gpu" or end in ``.cu``/``.cuh``, and must not contain
    "cudnn".
    """
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("c10/cuda"):
        return True
    base_name = os.path.basename(rel_filepath)
    if "cudnn" in base_name:
        return False
    suffix = os.path.splitext(base_name)[1]
    return "gpu" in base_name or suffix in [".cu", ".cuh"]
class TrieNode:
    """A Trie node whose children are represented as a dictionary of char: TrieNode.

    A special char '' represents end of word; its value is ``True`` rather
    than a TrieNode.
    """

    def __init__(self):
        self.children = {}


class Trie:
    """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.

    The corresponding Regex should match much faster than a simple Regex union.
    """

    def __init__(self):
        """Initialize the trie with an empty root node."""
        self.root = TrieNode()

    def add(self, word):
        """Add a word to the Trie."""
        node = self.root
        for char in word:
            node = node.children.setdefault(char, TrieNode())
        node.children[""] = True  # Mark the end of the word

    def dump(self):
        """Return the root node of Trie."""
        return self.root

    def quote(self, char):
        """Escape a char for regex."""
        return re.escape(char)

    def search(self, word):
        """Search whether word is present in the Trie.

        Returns True if yes, else return False.
        """
        node = self.root
        for char in word:
            if char not in node.children:
                return False
            node = node.children[char]

        # make sure to check the end-of-word marker present
        return "" in node.children

    def _pattern(self, root):
        """Convert the Trie rooted at ``root`` into a regular expression pattern.

        Returns ``None`` when ``root`` only carries the end-of-word marker,
        i.e. nothing follows it.
        """
        node = root

        if "" in node.children and len(node.children) == 1:
            return None

        alt = []  # store alternative patterns
        cc = []  # store char to char classes
        ends_word = False  # whether a word may end at this node
        for char in sorted(node.children):
            child = node.children[char]
            if not isinstance(child, TrieNode):
                ends_word = True
                continue
            suffix = self._pattern(child)
            # Test for "no suffix" explicitly rather than catching the
            # TypeError from ``str + None``; a broad except would also hide
            # genuine failures raised deeper in the recursion.
            if suffix is None:
                cc.append(self.quote(char))
            else:
                alt.append(self.quote(char) + suffix)

        cconly = not alt

        if cc:
            alt.append(cc[0] if len(cc) == 1 else "[" + "".join(cc) + "]")

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if ends_word:
            # Everything after this node is optional because a word may stop here.
            if cconly:
                result += "?"
            else:
                result = f"(?:{result})?"
        return result

    def pattern(self):
        """Export the Trie to a regex pattern."""
        return self._pattern(self.root)

    def export_to_regex(self):
        """Export the Trie to a regex pattern."""
        return self._pattern(self.root)
# Tries of source identifiers and the source -> replacement tables, one pair
# per target; both are filled by the loop over CUDA_TO_HIP_MAPPINGS below.
CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}
PYTORCH_TRIE = Trie()
PYTORCH_MAP: Dict[str, object] = {}

# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
PYTORCH_SPECIAL_MAP = {}

for mapping in CUDA_TO_HIP_MAPPINGS:
    assert isinstance(mapping, Mapping)
    for src, value in mapping.items():
        # value[0] is the replacement text; the remaining items are API tags.
        dst = value[0]
        meta_data = value[1:]
        if constants.API_CAFFE2 not in meta_data:
            PYTORCH_TRIE.add(src)
            # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL
            # do not overwrite PYTORCH_MAP, store dst separately
            if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
                PYTORCH_SPECIAL_MAP[src] = dst
            else:
                PYTORCH_MAP[src] = dst
        # Entries tagged neither API_PYTORCH nor API_SPECIAL also feed the Caffe2 tables.
        if (
            constants.API_PYTORCH not in meta_data
            and constants.API_SPECIAL not in meta_data
        ):
            CAFFE2_TRIE.add(src)
            CAFFE2_MAP[src] = dst
# The Caffe2 pattern matches the identifiers anywhere in the text.
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
# The PyTorch pattern only matches identifiers that have a non-word character
# on both sides (lookbehind / lookahead on \W).
RE_PYTORCH_PREPROCESSOR = re.compile(
    rf"(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)"
)

# Patterns for the include / macro / suffix rewrites; group 1 is the path.
RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r"#include <([^>]+)>")
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r"\.cu\b")  # be careful not to pick up .cuh
"""
Returns a HipifyResult object with the following details:
"hipified_path" : absolute path of hipified source file
"status" : "ok" if hipified file was written out
"skipped" if an identical hipified file already existed or hipified file couldn't be written out
"ignored" if the source file was a hipified file itself or not meant to be hipified
"current_state" : CurrentState.INITIALIZED if source file is first ready to be hipified
CurrentState.DONE if source file is done with hipification process
"""
def preprocessor(
    output_directory: str,
    filepath: str,
    all_files: Iterable,
    header_include_dirs: Iterable,
    stats: Dict[str, List],
    hip_clang_launch: bool,
    is_pytorch_extension: bool,
    clean_ctx: GeneratedFileCleaner,
    show_progress: bool,
) -> HipifyResult:
    """Executes the CUDA -> HIP conversion on the specified file.

    Args:
        output_directory: Directory that ``filepath`` and the hipified
            output path are resolved against.
        filepath: File to convert; ignored unless it appears in ``all_files``.
        all_files: Collection of files that are meant to be hipified.
        header_include_dirs: Directories (joined onto ``output_directory``)
            searched, in order, for headers included by an extension source.
        stats: Statistics dict handed to ``processKernelLaunches``.
        hip_clang_launch: When False, ``<<<...>>>`` kernel launches are
            rewritten; when True they are left alone.
        is_pytorch_extension: Selects the extension naming and mapping rules.
        clean_ctx: Used to create the output directory and open the output file.
        show_progress: Forwarded when an included header is hipified first.

    Returns:
        The ``HipifyResult`` registered for the input file, with
        ``hipified_path``, ``status`` and ``current_state`` filled in. The
        status is one of "[ok]", "[ignored, not to be hipified]",
        "[ignored, input is hipified output]", "[skipped, no changes]",
        "[skipped, already hipified]" or "[skipped, no permissions]".
    """
    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    hipify_result = HIPIFY_FINAL_RESULT[fin_path]
    if filepath not in all_files:
        hipify_result.hipified_path = None
        hipify_result.status = "[ignored, not to be hipified]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result

    rel_filepath = os.path.relpath(filepath, output_directory)

    with open(fin_path, encoding="utf-8") as fin:
        # A leading breadcrumb line marks a file that is itself hipify output.
        if fin.readline() == HIPIFY_C_BREADCRUMB:
            hipify_result.hipified_path = None
            hipify_result.status = "[ignored, input is hipified output]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
        fin.seek(0)
        output_source = fin.read()

    orig_output_source = output_source

    # get_hip_file_path needs a relative path to work correctly
    fout_path = os.path.abspath(
        os.path.join(
            output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)
        )
    )
    if not os.path.exists(os.path.dirname(fout_path)):
        clean_ctx.makedirs(os.path.dirname(fout_path))

    # unsupported_calls statistics reporting is broken atm
    def pt_repl(m):
        return PYTORCH_MAP[m.group(0)]

    def pt_special_repl(m):
        # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings
        return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))

    if is_pytorch_extension:
        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
    else:
        if is_special_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source)
        elif is_pytorch_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
        else:

            def c2_repl(m):
                return CAFFE2_MAP[m.group(0)]

            output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)

    # Header rewrites
    def mk_repl(templ, include_current_dir=True):
        def repl(m):
            f = m.group(1)
            dirpath, filename = os.path.split(f)
            if f.startswith(
                (
                    "ATen/cuda",
                    "ATen/native/cuda",
                    "ATen/native/nested/cuda",
                    "ATen/native/quantized/cuda",
                    "ATen/native/sparse/cuda",
                    "ATen/native/transformers/cuda",
                    "THC/",
                )
            ) or (f.startswith("THC") and not f.startswith("THCP")):
                return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
            # if filename is one of the files being hipified for this extension
            if is_pytorch_extension and any(s.endswith(filename) for s in all_files):
                header_dir = None
                header_filepath = None
                # If include_current_dir True, look first in same dir as the including source file
                if include_current_dir:
                    header_dir_to_check = os.path.dirname(fin_path)
                    header_path_to_check = os.path.abspath(
                        os.path.join(header_dir_to_check, f)
                    )
                    if os.path.exists(header_path_to_check):
                        header_dir = header_dir_to_check
                        header_filepath = header_path_to_check
                # If not found, look in include dirs one by one and first match wins
                if header_filepath is None:
                    for header_include_dir in header_include_dirs:
                        header_dir_to_check = os.path.join(
                            output_directory, header_include_dir
                        )
                        header_path_to_check = os.path.abspath(
                            os.path.join(header_dir_to_check, f)
                        )
                        if os.path.exists(header_path_to_check):
                            header_dir = header_dir_to_check
                            header_filepath = header_path_to_check
                            # Stop at the first hit; without this a later
                            # include dir would override an earlier one.
                            break
                # If header file not found, keep as is
                if header_filepath is None:
                    return m.group(0)
                # Hipify header file first if needed
                if header_filepath not in HIPIFY_FINAL_RESULT:
                    preprocess_file_and_save_result(
                        output_directory,
                        header_filepath,
                        all_files,
                        header_include_dirs,
                        stats,
                        hip_clang_launch,
                        is_pytorch_extension,
                        clean_ctx,
                        show_progress,
                    )
                elif header_filepath in HIPIFY_FINAL_RESULT:
                    header_result = HIPIFY_FINAL_RESULT[header_filepath]
                    if header_result.current_state == CurrentState.INITIALIZED:
                        # The header is registered but not finished yet, so
                        # compute and record its output path here.
                        # get_hip_file_path needs a relative path to work correctly
                        header_rel_path = os.path.relpath(
                            header_filepath, output_directory
                        )
                        header_fout_path = os.path.abspath(
                            os.path.join(
                                output_directory,
                                get_hip_file_path(
                                    header_rel_path, is_pytorch_extension
                                ),
                            )
                        )
                        header_result.hipified_path = header_fout_path
                        HIPIFY_FINAL_RESULT[header_filepath] = header_result
                        return templ.format(
                            os.path.relpath(
                                (
                                    header_fout_path
                                    if header_fout_path is not None
                                    else header_filepath
                                ),
                                header_dir,
                            )
                        )
                hipified_header_filepath = HIPIFY_FINAL_RESULT[
                    header_filepath
                ].hipified_path
                return templ.format(
                    os.path.relpath(
                        (
                            hipified_header_filepath
                            if hipified_header_filepath is not None
                            else header_filepath
                        ),
                        header_dir,
                    )
                )

            return m.group(0)

        return repl

    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
    output_source = RE_ANGLE_HEADER.sub(mk_repl("#include <{0}>", False), output_source)
    output_source = RE_THC_GENERIC_FILE.sub(
        mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source
    )

    # CMakeLists.txt rewrites
    if filepath.endswith("CMakeLists.txt"):
        output_source = output_source.replace("CUDA", "HIP")
        output_source = output_source.replace("THC", "THH")
        output_source = RE_CU_SUFFIX.sub(".hip", output_source)

    # Perform Kernel Launch Replacements
    if not hip_clang_launch:
        output_source = processKernelLaunches(output_source, stats)

    # Replace std:: with non-std:: versions
    if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
        output_source = replace_math_functions(output_source)

    # Include header if device code is contained.
    output_source = hip_header_magic(output_source)

    # Replace the extern __shared__
    # NOTE: No longer needed after transition from hcc to hipclang.
    # output_source = replace_extern_shared(output_source)

    # Don't write out identical hipified files for extensions if dirpath has not changed
    if (
        is_pytorch_extension
        and orig_output_source == output_source
        and os.path.dirname(fin_path) == os.path.dirname(fout_path)
    ):
        hipify_result.hipified_path = fin_path
        hipify_result.status = "[skipped, no changes]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result

    # Add hipify breadcrumb for C-style files to avoid re-hipification
    if fin_path != fout_path and match_extensions(
        fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")
    ):
        output_source = HIPIFY_C_BREADCRUMB + output_source

    # Only rewrite the output when its content would actually change.
    do_write = True
    if os.path.exists(fout_path):
        with open(fout_path, encoding="utf-8") as fout_old:
            do_write = fout_old.read() != output_source
    if do_write:
        try:
            with clean_ctx.open(fout_path, "w", encoding="utf-8") as fout:
                fout.write(output_source)
            hipify_result.hipified_path = fout_path
            hipify_result.status = "[ok]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
        except PermissionError as e:
            print(
                f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}',
                file=sys.stderr,
            )
            hipify_result.hipified_path = fin_path
            hipify_result.status = "[skipped, no permissions]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
    else:
        hipify_result.hipified_path = fout_path
        hipify_result.status = "[skipped, already hipified]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result
def file_specific_replacement(filepath, search_string, replace_string, strict=False):
    """Replace ``search_string`` with ``replace_string`` inside a file, in place.

    With ``strict`` set, only occurrences delimited by word boundaries are
    replaced; otherwise every occurrence is replaced. The file is rewritten
    from the start and truncated to the new length.
    """
    with openf(filepath, "r+") as handle:
        text = handle.read()
        if not strict:
            text = text.replace(search_string, replace_string)
        else:
            whole_word = rf"\b({re.escape(search_string)})\b"
            # A callable replacement keeps backslashes in replace_string literal.
            text = re.sub(whole_word, lambda x: replace_string, text)
        handle.seek(0)
        handle.write(text)
        handle.truncate()
def file_add_header(filepath, header):
    """Prepend an ``#include`` line for ``header`` to the file, in place.

    The header is wrapped in double quotes unless it starts with ``<`` or
    ends with ``>``.
    """
    with openf(filepath, "r+") as handle:
        body = handle.read()
        looks_angled = header[0] == "<" or header[-1] == ">"
        if not looks_angled:
            header = f'"{header}"'
        body = f"#include {header} \n" + body
        handle.seek(0)
        handle.write(body)
        handle.truncate()
def fix_static_global_kernels(in_txt):
    """Static global kernels in HIP results in a compilation error.

    Every " __global__ static" (with its leading space) is replaced by
    "__global__".
    """
    return in_txt.replace(" __global__ static", "__global__")
# Matches "#include " and the rest of that line, including the trailing newline.
RE_INCLUDE = re.compile(r"#include .*\n")
def extract_arguments(start, string):
    """Return the list of arguments in the upcoming function parameter closure.

    ``start`` is the index of the opening parenthesis. Each argument is
    reported as a dict with its ``start`` index and the index of the comma or
    closing parenthesis that ends it (``end``). Commas nested inside inner
    parentheses or ``<...>`` do not split arguments.

    Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
            {'start': 8, 'end': 16},
            {'start': 17, 'end': 19},
            {'start': 20, 'end': 53}]'
    """
    found = []
    paren_depth = 0
    angle_depth = 0
    arg_begin = start + 1

    # Search for final parenthesis
    for index in range(start, len(string)):
        symbol = string[index]
        if symbol == "(":
            paren_depth += 1
        elif symbol == ")":
            paren_depth -= 1
        elif symbol == "<":
            angle_depth += 1
        elif symbol == ">" and string[index - 1] != "-" and angle_depth > 0:
            # A ">" that is part of "->" does not close a template bracket.
            angle_depth -= 1

        # Finished all arguments
        if paren_depth == 0 and angle_depth == 0:
            # Add final argument
            found.append({"start": arg_begin, "end": index})
            break

        # Finished current argument
        if paren_depth == 1 and angle_depth == 0 and symbol == ",":
            found.append({"start": arg_begin, "end": index})
            arg_begin = index + 1

    return found
def str2bool(v):
    """Convert a yes/no style string to a bool.

    ArgumentParser doesn't support type=bool, so this helper maps the
    accepted spellings (case-insensitive) to True / False.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")
    if v.lower() in truthy:
        return True
    if v.lower() in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def hipify(
    project_directory: str,
    show_detailed: bool = False,
    extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
    header_extensions: Iterable = (".cuh", ".h", ".hpp"),
    output_directory: str = "",
    header_include_dirs: Iterable = (),
    includes: Iterable = ("*",),
    extra_files: Iterable = (),
    out_of_place_only: bool = False,
    ignores: Iterable = (),
    show_progress: bool = True,
    hip_clang_launch: bool = False,
    is_pytorch_extension: bool = False,
    hipify_extra_files_only: bool = False,
    clean_ctx: Optional[GeneratedFileCleaner] = None,
) -> HipifyFinalResult:
    """Hipify every matching file of a project and return the collected results.

    Args:
        project_directory: Project root; the current working directory is
            used when this is the empty string.
        show_detailed: When True, ``compute_stats`` is called at the end.
        extensions: File extensions considered when collecting files.
        header_extensions: Extensions accepted for files found under
            ``header_include_dirs``.
        output_directory: Where the work happens; defaults to the project
            directory with "_hygon" appended. The project is copied there if
            the directory does not exist yet.
        header_include_dirs: Directories (absolute, or relative to the output
            directory) whose matching headers are added to the file list.
        includes: Glob patterns a file must match.
        extra_files: Additional files to hipify; relative entries are
            resolved against the output directory.
        out_of_place_only: Forwarded to ``matched_files_iter``.
        ignores: Glob patterns that exclude a file.
        show_progress: Forwarded to the per-file preprocessing.
        hip_clang_launch: Forwarded to the per-file preprocessing.
        is_pytorch_extension: Forwarded to file matching and preprocessing.
        hipify_extra_files_only: When True, only ``extra_files`` are processed.
        clean_ctx: File cleaner context; one that keeps intermediates is
            created when omitted.

    Returns:
        The module-level ``HIPIFY_FINAL_RESULT`` mapping.

    Exits the process with status 1 when ``project_directory`` does not exist.
    """
    if project_directory == "":
        project_directory = os.getcwd()

    # Verify the project directory exists.
    if not os.path.exists(project_directory):
        print("The project folder specified does not exist.")
        sys.exit(1)

    # If no output directory, provide a default one.
    if not output_directory:
        # str.rstrip returns a new string, so the result has to be kept;
        # otherwise "proj/" would produce "proj/_hygon". Fall back to the
        # original value if stripping leaves nothing (e.g. "/").
        project_directory = project_directory.rstrip("/") or project_directory
        output_directory = project_directory + "_hygon"

    if project_directory != output_directory:
        includes = [
            include.replace(project_directory, output_directory) for include in includes
        ]
        ignores = [
            ignore.replace(project_directory, output_directory) for ignore in ignores
        ]

    # Copy from project directory to output directory if not done already.
    if not os.path.exists(output_directory):
        shutil.copytree(project_directory, output_directory)

    all_files = list(
        matched_files_iter(
            output_directory,
            includes=includes,
            ignores=ignores,
            extensions=extensions,
            out_of_place_only=out_of_place_only,
            is_pytorch_extension=is_pytorch_extension,
        )
    )
    all_files_set = set(all_files)
    for f in extra_files:
        if not os.path.isabs(f):
            f = os.path.join(output_directory, f)
        if f not in all_files_set:
            all_files.append(f)

    # List all files in header_include_paths to ensure they are hipified
    from pathlib import Path

    for header_include_dir in header_include_dirs:
        if os.path.isabs(header_include_dir):
            header_include_dir_path = Path(header_include_dir)
        else:
            header_include_dir_path = Path(
                os.path.join(output_directory, header_include_dir)
            )
        for path in header_include_dir_path.rglob("*"):
            if (
                path.is_file()
                and _fnmatch(str(path), includes)
                and (not _fnmatch(str(path), ignores))
                and match_extensions(path.name, header_extensions)
            ):
                all_files.append(str(path))

    if clean_ctx is None:
        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)

    # Preprocessing statistics.
    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}

    for filepath in all_files if not hipify_extra_files_only else extra_files:
        preprocess_file_and_save_result(
            output_directory,
            filepath,
            all_files,
            header_include_dirs,
            stats,
            hip_clang_launch,
            is_pytorch_extension,
            clean_ctx,
            show_progress,
        )

    print(
        bcolors.OKGREEN
        + "Successfully preprocessed all matching files."
        + bcolors.ENDC,
        file=sys.stderr,
    )

    # Show detailed summary
    if show_detailed:
        compute_stats(stats)

    return HIPIFY_FINAL_RESULT
# SPDX-License-Identifier: MIT
from packaging import version
from packaging.version import Version
import importlib
from typing import Any, Callable, Optional, Union, List, get_args, get_origin
# Module-wide handle for the "aiter" torch Library; torch_compile_guard
# creates it the first time it is needed and reuses it afterwards.
aiter_lib = None
def is_torch_equal_or_newer(target: str) -> bool:
    """Check if the installed torch version is >= the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    import torch

    try:
        return _is_torch_equal_or_newer(str(torch.__version__), target)
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc gen.
        # ``import importlib`` alone does not guarantee that the ``metadata``
        # submodule is loaded, so import it explicitly before using it.
        import importlib.metadata

        return Version(importlib.metadata.version("torch")) >= Version(target)
# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    """Return True when ``torch_version`` parses to a version >= ``target``."""
    installed = version.parse(torch_version)
    required = version.parse(target)
    return installed >= required
# Names of ops whose schema string is built by generate_schema() instead of
# torch's schema inference; torch_compile_guard checks the wrapped function's
# __name__ against this list. Commented-out entries are currently disabled.
MANUAL_SCHEMA_OPS = [
    "register_graph_buffers",
    # "module_moe_ck2stages",
    # "mha_fwd",
    # "fmha_v3_fwd",
    # "mha_varlen_fwd",
    # "mha_bwd",
    # "fmha_v3_bwd",
    # "mha_varlen_bwd",
    # "fmha_v3_varlen_bwd",
    # "fmha_v3_varlen_fwd",
    # "mha_batch_prefill",
    "hipb_findallsols",
    "rocb_findallsols",
    "_ActivationType",
    "_QuantType",
    "init_custom_ar",
    # "greedy_sample",
    # "random_sample",
    # "mixed_sample",
    # "exponential",
]
# Names of ops for which torch_compile_guard returns the plain pass-through
# wrapper right away, without building a schema. Commented-out entries are
# currently disabled.
NONE_WRAPPED_OP = [
    "hipb_create_extension",
    # "hipb_destroy_extension",
    # "getHipblasltKernelName",
    # "rocb_create_extension",
    # "rocb_destroy_extension",
    "get_meta_buffer_ipc_handle",
    "get_graph_buffer_ipc_meta",
    "asm_moe_get_solutions",
    "ck_moe_get_solutions",
    "_ActivationType",
    "_QuantType",
    "get_moe_asm_solution",
    # "allocate_meta_buffer",
    # "dispose",
    # "meta_size",
    # "get_padded_m",
    # "compile_mha_fwd",
    # "compile_mha_bwd",
    "init_custom_qr",
    # "qr_max_size",
    # "qr_destroy",
    # "qr_open_handles",
    # "qr_get_handle",
]
def generate_schema(func, mutates_args: Union[list[str], str] = "unknown") -> str:
    """Build a schema string of the form ``(params) -> return`` from annotations.

    Args:
        func: Callable whose signature annotations are inspected.
        mutates_args: Names of parameters treated as mutated. With the
            default ``"unknown"`` every tensor-like parameter is marked as
            mutated (``Tensor(aN!)``, N being the parameter index).

    Returns:
        The schema string. Parameters with an unsupported annotation are
        rendered as ``"* "`` (no name, no default); an unsupported return
        annotation is rendered as ``Any``.
    """
    import inspect

    import torch

    signature = inspect.signature(func)

    def _tensor_decl(position, mutated, suffix):
        # Mutated tensors carry an alias annotation keyed by parameter index.
        alias = f"(a{position}!)" if mutated else ""
        return f"Tensor{alias}{suffix}"

    rendered = []
    for position, (name, param) in enumerate(signature.parameters.items()):
        annotation = param.annotation
        mutated = mutates_args == "unknown" or name in mutates_args

        if annotation is torch.Tensor:
            type_str = _tensor_decl(position, mutated, "")
        elif annotation == Optional[torch.Tensor] or (
            get_origin(annotation) is Union and torch.Tensor in get_args(annotation)
        ):
            type_str = _tensor_decl(position, mutated, "?")
        elif annotation in (torch.SymInt, int):
            type_str = "SymInt"
        elif annotation in (float, bool, str):
            type_str = annotation.__name__
        elif annotation == Optional[torch.Generator]:
            type_str = "Generator?"
        elif (
            get_origin(annotation) in (list, List)
            and get_args(annotation)[0] is torch.Tensor
        ):
            type_str = _tensor_decl(position, mutated, "[]")
        elif get_origin(annotation) in (list, List) and get_args(annotation)[0] is int:
            type_str = "int[]"
        elif annotation == Optional[torch.dtype]:
            type_str = "ScalarType?"
        else:
            type_str = None

        if type_str is None:
            # Unsupported annotation: wildcard placeholder only.
            rendered.append("* ")
            continue

        declaration = f"{type_str} {name}"
        if param.default != inspect.Parameter.empty:
            if param.default is None:
                declaration += "=None"
            else:
                declaration += f"={param.default}"
        rendered.append(declaration)

    scalar_labels = (
        (torch.Tensor, "Tensor"),
        (int, "int"),
        (float, "float"),
        (bool, "bool"),
    )

    def _scalar_label(candidate):
        for known_type, label in scalar_labels:
            if candidate is known_type:
                return label
        return None

    returns = signature.return_annotation
    if returns is type(None) or returns is None:
        return_type = "()"
    elif _scalar_label(returns) is not None:
        return_type = _scalar_label(returns)
    elif get_origin(returns) is list:
        element = get_args(returns)[0]
        if element is int:
            return_type = "int[]"
        elif element is torch.Tensor:
            return_type = "Tensor[]"
        else:
            return_type = "Any"
    elif get_origin(returns) is tuple:
        # Tuple members with an unsupported type are left out of the list.
        labels = [_scalar_label(member) for member in get_args(returns)]
        members = [label for label in labels if label is not None]
        return_type = f"({', '.join(members)})"
    else:
        return_type = "Any"

    return f"({', '.join(rendered)}) -> {return_type}"
def torch_compile_guard(
    mutates_args: Union[list[str], str] = "unknown",
    device: str = "cpu",
    calling_func_: Optional[Callable[..., Any]] = None,
    gen_fake: Optional[Callable[..., Any]] = None,
):
    """Build a decorator that routes calls through a ``torch.ops.aiter`` custom op.

    The decorator derives an op schema from the target function, registers the
    op in the shared ``aiter`` library fragment (CUDA and CPU implementations
    plus a fake implementation) and returns a wrapper that invokes the
    registered op.

    Two schema adjustments are applied so every op has a Tensor input and a
    Tensor output:

    * If the first parameter is not annotated as ``torch.Tensor``, a leading
      ``Tensor dummy`` input is added and the wrapper passes
      ``torch.empty(1, device=device)`` for it.
    * If the return annotation is ``int``, ``bool`` or ``float``, the op
      returns ``(Tensor, value)`` and the wrapper hands back ``value`` only.

    Args:
        mutates_args: Forwarded to schema inference. On the torch 2.4 fallback
            path (no ``torch.library.infer_schema``) the value ``"unknown"`` is
            expanded to the names of all parameters annotated ``torch.Tensor``.
        device: Device used for the placeholder tensors created here.
        calling_func_: Function whose name and signature define the op.
            Defaults to the decorated function.
        gen_fake: Optional fake implementation. When omitted, the target
            function itself is used as the fake implementation.

    Returns:
        A decorator. It returns a plain pass-through wrapper when torch cannot
        be imported or the op name is listed in ``NONE_WRAPPED_OP``; otherwise
        it returns the wrapper that dispatches through ``torch.ops.aiter``.
    """

    def decorator(func):
        # core.py decorates a wrapper, but the op has to be described by the
        # real aiter op function, which arrives through calling_func_.
        calling_func = calling_func_ if calling_func_ is not None else func

        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        try:
            import torch
            from torch.library import Library
            import inspect
        except ImportError:
            return wrapper

        if calling_func.__name__ in NONE_WRAPPED_OP:
            return wrapper

        def wrapper_register(calling_func):
            """Create the shared library on first use and return the op schema."""
            import inspect
            import torch
            import torch.library
            from torch.library import Library

            global aiter_lib
            aiter_lib = Library("aiter", "FRAGMENT") if aiter_lib is None else aiter_lib
            schema = ""
            if calling_func.__name__ in MANUAL_SCHEMA_OPS:
                schema = generate_schema(calling_func)
            else:
                sig = inspect.signature(calling_func)
                if hasattr(torch.library, "infer_schema"):
                    schema = torch.library.infer_schema(
                        calling_func, mutates_args=mutates_args
                    )
                else:
                    # for pytorch 2.4
                    import torch._custom_op.impl

                    # Start from the caller's value so an explicit list is
                    # forwarded as-is; previously this name was only bound in
                    # the "unknown" case and any other value hit a NameError.
                    mutates_args_custom = mutates_args
                    # torch 2.4 does not support mutates "unknown", so treat
                    # every Tensor-annotated parameter as mutated instead.
                    if mutates_args == "unknown":
                        mutates_args_custom = [
                            param_name
                            for param_name, param in sig.parameters.items()
                            if param.annotation == torch.Tensor
                        ]
                    schema = torch._custom_op.impl.infer_schema(
                        calling_func, mutates_args_custom
                    )
            return schema

        schema = wrapper_register(calling_func)
        sig = inspect.signature(calling_func)
        input_is_tensor = False
        parameters = list(sig.parameters.values())
        if parameters:
            first_param = parameters[0]
            if (
                first_param.annotation is not inspect.Parameter.empty
                and first_param.annotation is torch.Tensor
            ):
                input_is_tensor = True

        input_part, output_part = schema.split("->", 1)
        if input_is_tensor:
            new_input = input_part
        else:
            # Prepend a placeholder Tensor input to the inferred parameter list.
            if not sig.parameters:
                new_input = "(Tensor dummy)"
            else:
                new_input = "(Tensor dummy, " + input_part[1:]

        return_non_tensor = False
        return_annotation = sig.return_annotation
        if return_annotation in [int, bool, float]:
            # Pair the scalar result with a placeholder Tensor output.
            output_part = "(Tensor, " + output_part + ")"
            return_non_tensor = True

        schema = f"{new_input} -> {output_part}".strip()
        loadName = calling_func.__name__

        def wrapper_custom(*args, **kwargs):
            result = (
                getattr(torch.ops.aiter, f"{loadName}")(*args, **kwargs)
                if input_is_tensor
                else getattr(torch.ops.aiter, f"{loadName}")(
                    torch.empty(1, device=device), *args, **kwargs
                )
            )
            # Drop the placeholder Tensor that accompanies scalar results.
            return result[1] if return_non_tensor else result

        # The op is already registered: only the dispatching wrapper is needed.
        if hasattr(torch.ops.aiter, loadName):
            return wrapper_custom

        def abstract_impl(*args, **kwargs):
            if gen_fake is not None:
                if return_non_tensor:
                    return torch.empty(1, device=device), gen_fake(*args, **kwargs)
                else:
                    return gen_fake(*args, **kwargs)
            if return_non_tensor:
                return torch.empty(1, device=device), calling_func(*args, **kwargs)
            return calling_func(*args, **kwargs)

        def outer_wrapper(*args, **kwargs):
            return (
                wrapper(*args, **kwargs)
                if not return_non_tensor
                else (torch.empty(1, device=device), wrapper(*args, **kwargs))
            )

        def abstract_impl_dummy(dummy, *args, **kwargs):
            if gen_fake is not None:
                if return_non_tensor:
                    return torch.empty(1, device=device), gen_fake(*args, **kwargs)
                else:
                    return gen_fake(*args, **kwargs)
            if return_non_tensor:
                return torch.empty(1, device=device), calling_func(*args, **kwargs)
            return calling_func(*args, **kwargs)

        def outer_wrapper_dummy(dummy, *args, **kwargs):
            return (
                wrapper(*args, **kwargs)
                if not return_non_tensor
                else (torch.empty(1, device=device), wrapper(*args, **kwargs))
            )

        # The *_dummy variants swallow the placeholder first argument.
        custom_func = outer_wrapper
        fake_func = abstract_impl
        if not input_is_tensor:
            custom_func = outer_wrapper_dummy
            fake_func = abstract_impl_dummy

        if not hasattr(torch.ops.aiter, calling_func.__name__):
            if is_torch_equal_or_newer("2.8.0"):
                tags = ()
            else:
                tags = (torch.Tag.needs_fixed_stride_order,)
            op_schema = f"aiter::{loadName}" + schema
            aiter_lib.define(op_schema, tags=tags)
            aiter_lib.impl(f"aiter::{loadName}", custom_func, dispatch_key="CUDA")
            aiter_lib.impl(f"aiter::{loadName}", custom_func, dispatch_key="CPU")
            aiter_lib._register_fake(f"{loadName}", fake_func)
        return wrapper_custom

    return decorator
\ No newline at end of file
# SPDX-License-Identifier: MIT
# user interface
import torch
import aiter
from aiter import dtypes
import triton
import triton.language as tl
import functools
from .jit.utils.chip_info import get_cu_num
@triton.jit
def _fwd_kernel_stage2_asm(
    Mid_O,
    Mid_lse,
    O,
    qo_indptr,
    kv_indptr,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
    mgc: tl.constexpr,
):
    """Merge the per-KV-split partial outputs of one (query row, head) pair.

    Program ids are (batch, head, query offset inside the batch). For every
    non-empty KV split the partial output from ``Mid_O`` is combined with the
    running accumulator using the log-sum-exp values in ``Mid_lse`` (running
    maximum plus rescaling), and the normalised result is stored to ``O``.

    The ``stride_mid_*`` arguments index ``Mid_lse``; ``Mid_O`` is addressed
    with the same offsets multiplied by ``Lv``.
    """
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_qo_offs = tl.program_id(2)

    cur_qo_start = tl.load(qo_indptr + cur_batch)
    cur_qo_end = tl.load(qo_indptr + cur_batch + 1)
    cur_qo = cur_qo_start + cur_qo_offs
    # Rows of this batch are [cur_qo_start, cur_qo_end). The end index itself
    # belongs to the next batch (or lies past the last row), so it must be
    # skipped as well; the previous `>` test let it through.
    if cur_qo >= cur_qo_end:
        return

    cur_kv_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = (cur_qo * stride_mid_ob + cur_head * stride_mid_oh) * Lv + offs_d
    offs_logic = cur_qo * stride_mid_ob + cur_head * stride_mid_oh

    # The split length does not depend on the split index: compute it once.
    kv_len_per_split = tl.maximum(mgc, tl.cdiv(cur_kv_seq_len, NUM_KV_SPLITS))

    for split_kv_id in range(0, NUM_KV_SPLITS):
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_kv_seq_len)

        # Only splits that cover at least one KV position contribute.
        if split_kv_end > split_kv_start:
            tv = tl.load(
                Mid_O + offs_v + split_kv_id * stride_mid_os * Lv,
                mask=mask_d,
                other=0.0,
            )
            tlogic = tl.load(Mid_lse + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            # Rescale what has been accumulated so far to the new maximum.
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_qo * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )
@functools.lru_cache()
def get_meta_param(num_kv_splits, device, bs, nhead):
    """Return ``(num_kv_splits, mgc)`` for the MLA decode kernels.

    When ``num_kv_splits`` is None it is derived from the compute-unit count
    reported by ``get_cu_num()`` divided by the batch size, clamped to the
    range 1..16. ``mgc`` is looked up from the head count; only 16 and 128
    heads are accepted. ``device`` is not read here, but it is part of the
    cache key. Results are memoised per argument tuple.
    """
    mgc_by_nhead = {16: 64, 128: 16}
    if num_kv_splits is None:
        num_kv_splits = min(16, max(1, get_cu_num() // bs))
    assert nhead in mgc_by_nhead, f"{nhead=} not supported"
    return num_kv_splits, mgc_by_nhead[nhead]
def mla_decode_fwd(
    q,
    kv_buffer,
    o,
    qo_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_lens,
    max_seqlen_q,
    sm_scale=None,  # 1.0 / (qk_head_dim**0.5)
    logit_cap=0.0,
    num_kv_splits=None,  # for experts only!!!
):
    """Run MLA decode attention: stage-1 asm kernel, then the split merge.

    Stage 1 (``aiter.mla_decode_stage1_asm_fwd``) writes per-split partial
    outputs and log-sum-exp values. Unless the single-split shortcut applies,
    the Triton stage-2 kernel then merges the splits into ``o``.

    With 128 heads and a single KV split the stage-1 kernel writes straight
    into a view of ``o`` and the merge is skipped.

    Returns:
        ``(logits, attn_lse)``. In the single-split 128-head case ``logits`` is
        the view of ``o`` reshaped to ``(total_s, nhead, v_head_dim)``;
        otherwise it is the fp32 per-split buffer.
    """
    device = q.device
    assert logit_cap <= 0, f"{logit_cap=} is not support yet"

    # Only the last dimension of the 4-D KV buffer is needed here.
    _, _, _, qk_head_dim = kv_buffer.shape
    if sm_scale is None:
        sm_scale = 1.0 / (qk_head_dim**0.5)

    total_s, nhead, v_head_dim = o.shape
    bs = qo_indptr.shape[0] - 1
    num_kv_splits, mgc = get_meta_param(num_kv_splits, device, bs, nhead)

    split_shape = (total_s, num_kv_splits, nhead, v_head_dim)
    single_split_128 = num_kv_splits == 1 and nhead == 128

    if nhead == 16:
        logits = torch.empty(split_shape, dtype=dtypes.fp32, device=device)
        assert (
            max_seqlen_q == 1
        ), f"Assertion: max_seqlen_q should be 1 when n_head=16, but got {max_seqlen_q}"
    elif nhead == 128:
        if num_kv_splits == 1:
            # Let stage 1 write directly into the caller's output buffer.
            logits = o.view(split_shape)
        else:
            logits = torch.empty(split_shape, dtype=dtypes.fp32, device=device)
    else:
        assert False, f"{nhead=} not supported"

    attn_lse = torch.empty(
        (total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
    )

    aiter.mla_decode_stage1_asm_fwd(
        q,
        kv_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_q,
        sm_scale,
        logits,
        attn_lse,
    )

    if single_split_128:
        return logits.view(total_s, nhead, v_head_dim), attn_lse

    # One stage-2 program per (batch, head, query offset).
    launch_grid = (bs, nhead, max_seqlen_q)
    _fwd_kernel_stage2_asm[launch_grid](
        logits,
        attn_lse,
        o,
        qo_indptr,
        kv_indptr,
        attn_lse.stride(0),
        attn_lse.stride(2),
        attn_lse.stride(1),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=num_kv_splits,
        BLOCK_DV=triton.next_power_of_2(v_head_dim),
        Lv=v_head_dim,
        mgc=mgc,
        num_warps=4,
        num_stages=2,
        waves_per_eu=4,
    )
    return logits, attn_lse
def mla_prefill_fwd(
    q,  # [num_seqs, num_heads, head_size]
    kv_buffer,  # [num_page, page_size, num_kv_heads, kv_lora_rank + qk_rope_head_dim]
    o,  # [num_seqs, num_heads, v_head_dim]
    qo_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_lens,
    max_seqlen_q,
    sm_scale=None,  # 1.0 / (qk_head_dim**0.5)
    logit_cap=0.0,
    num_kv_splits=None,  # for experts only!!!
):
    """Run MLA prefill attention through ``aiter.mla_prefill_asm_fwd``.

    The kernel writes its result directly into ``o`` (through a 4-D view with
    a single KV split) and fills a freshly allocated fp32 log-sum-exp buffer.

    Args:
        sm_scale: Softmax scale. Defaults to ``1 / sqrt(qk_head_dim)`` where
            ``qk_head_dim`` is the last dimension of ``kv_buffer``.
        logit_cap: Must be ``<= 0``; logit capping is not supported.
        num_kv_splits: Accepted for signature parity but ignored; prefill
            always uses a single split.

    Returns:
        ``(o.view(bs, nhead, v_head_dim), attn_lse)``.
    """
    device = q.device
    assert logit_cap <= 0, f"{logit_cap=} is not support yet"

    # Unpack the KV buffer shape before the default scale is computed:
    # qk_head_dim used to be read before it was assigned, which raised
    # UnboundLocalError whenever sm_scale was left as None.
    num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape
    if sm_scale is None:
        sm_scale = 1.0 / (qk_head_dim**0.5)

    bs, nhead, v_head_dim = o.shape
    num_kv_splits = 1

    # logits aliases o, so the kernel output lands in the caller's buffer.
    logits = o.view(bs, num_kv_splits, nhead, v_head_dim)
    attn_lse = torch.empty(
        (bs, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
    )

    aiter.mla_prefill_asm_fwd(
        q,
        kv_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_q,
        sm_scale,
        logits,
        attn_lse,
    )
    return o.view(bs, nhead, v_head_dim), attn_lse
# SPDX-License-Identifier: MIT
import logging
import torch
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple, List
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
class MoeSolutionType:
    """String identifiers for the MOE backends a config can select."""

    # One constant per backend; the value is what gets stored in
    # AiterMoeConfig.solution_type and compared against at dispatch time.
    MOE_C = "moe_c"
    ASM = "asm"
    TRITON = "triton"
    CK = "ck"
class MoeQuantType:
    """Quantization types supported by get_aiter_moe_config / aiter_moe."""

    # The names read as weight/activation bit widths ("w4a16" = 4-bit
    # weights with 16-bit activations) -- presumably; confirm with the kernels.
    W16A16 = "w16a16"
    W4A16 = "w4a16"
    W8A8 = "w8a8"
    W4A8 = "w4a8"
@dataclass
class AiterMoeConfig:
    """Config returned by :func:`get_aiter_moe_config`.

    All fields default to None, so a default-constructed instance represents
    "no solution".

    Attributes:
        quant_type: The quantization type this config was obtained for.
        solution_type: Which backend to use (MoeSolutionType constant), or
            None if no solution was found.
        config: Backend-specific config dict (opaque to the caller).
    """

    quant_type: Optional[str] = None
    solution_type: Optional[str] = None
    config: Optional[Dict[str, Any]] = None
def _pick_closest_config(configs: Dict[int, Any], m: int) -> Dict[str, Any]:
return configs[min(configs.keys(), key=lambda x: abs(x - m))]
def _try_get_moe_c_config(
    quant_type: str,
    m: int,
    e: int,
    n: int,
    block_size: int,
) -> Optional[Dict[str, Any]]:
    """Look up the moe_c config closest to *m* tokens, or return None.

    Supports W4A16, W8A8 and W4A8; only W4A8 forwards the block size (as both
    ``block_n`` and ``block_k``). Any exception raised during the lookup is
    logged at debug level and turned into None.
    """
    try:
        # Arguments that differ between the supported quant types.
        if quant_type == MoeQuantType.W4A16:
            marlin_kwargs = {"dtype": "int4_w4a16"}
        elif quant_type == MoeQuantType.W8A8:
            marlin_kwargs = {"dtype": "int8_w8a8"}
        elif quant_type == MoeQuantType.W4A8:
            marlin_kwargs = {
                "dtype": "int8_w4a8",
                "block_n": block_size,
                "block_k": block_size,
            }
        else:
            return None

        from .fused_moe_c import get_moe_configs_marlin

        tuned = get_moe_configs_marlin(
            E=e,
            N=n,
            is_bottom=False,
            use_moe_wna16_cuda=True,
            **marlin_kwargs,
        )
        return None if tuned is None else _pick_closest_config(tuned, m)
    except Exception as exc:
        logger.debug("moe_c config lookup failed for %s: %s", quant_type, exc)
        return None
def _try_get_asm_config(
    quant_type: str,
    m: int,
    e: int,
    n: int,
    k: int,
    top_k: int,
    block_size: Optional[int],
) -> Optional[Dict[str, Any]]:
    """Look up an ASM solution for the given problem, or return None.

    Handles W4A16, W8A8 and W16A16. A lookup result of ``"default"`` means no
    tuned solution exists and yields None. For W4A16 with ``block_size == 32``
    no lookup is done: a fixed solution is returned only when ``top_k <= 8``,
    ``n == 256`` and ``k == 7168``. Any exception is logged at debug level and
    turned into None.
    """
    try:
        from .fused_moe_asm_wna16 import get_moe_asm_solution, MoeQuantType as AsmMoeQuantType
        from .jit.utils.chip_info import get_gfx

        arch = get_gfx()

        def lookup(asm_quant_type):
            """Query the tuned ASM solution for this problem shape."""
            return get_moe_asm_solution(
                arch=arch,
                token=m,
                inter_dim=n,
                model_dim=k,
                expert=e,
                topk=top_k,
                quant_type=asm_quant_type,
            )

        if quant_type == MoeQuantType.W4A16:
            from .fused_moe_asm_wna16 import decode_sol_w4a16, decode_sol_w4a16_gw32

            if block_size == 32:
                if top_k > 8 or n != 256 or k != 7168:
                    return None
                return decode_sol_w4a16_gw32()

            solution = lookup(AsmMoeQuantType.INT4_W4A16)
            return None if solution == "default" else decode_sol_w4a16(solution)

        # W8A8 and W16A16 share the same decoder and differ only in the
        # quant type handed to the lookup.
        is_w8a8 = quant_type == MoeQuantType.W8A8
        if is_w8a8 or quant_type == MoeQuantType.W16A16:
            from .fused_moe_asm_wna16 import decode_sol_0

            solution = lookup(
                AsmMoeQuantType.INT8_W8A8 if is_w8a8 else AsmMoeQuantType.NO_QUANT
            )
            return None if solution == "default" else decode_sol_0(solution)

        return None
    except Exception as exc:
        logger.debug("ASM config lookup failed for %s: %s", quant_type, exc)
        return None
def _try_get_triton_config(
    quant_type: str,
    m: int,
    e: int,
    n: int,
    block_size: int,
) -> Optional[Dict[str, Any]]:
    """Look up the Triton config closest to *m* tokens, or return None.

    W16A16 needs no tuned config and gets an empty dict. W4A16 and W8A8 are
    looked up by dtype name; every other quant type yields None. Any
    exception (including a failed import) is logged at debug level and turned
    into None.
    """
    try:
        from .ops.triton.utils.moe_config_utils import get_moe_configs as triton_get_moe_configs

        if quant_type == MoeQuantType.W16A16:
            return {}  # Non-quantized; no tuned config lookup needed

        triton_dtype_names = {
            MoeQuantType.W4A16: "int4_w4a16",
            MoeQuantType.W8A8: "int8_w8a8",
        }
        if quant_type not in triton_dtype_names:
            return None

        tuned = triton_get_moe_configs(
            E=e,
            N=n,
            dtype=triton_dtype_names[quant_type],
            block_n=0,
            # A missing or zero block size is passed on as 0.
            block_k=block_size or 0,
            is_bottom=False,
        )
        return None if tuned is None else _pick_closest_config(tuned, m)
    except Exception as exc:
        logger.debug("Triton config lookup failed for %s: %s", quant_type, exc)
        return None
def _try_get_ck_config(
    quant_type: str,
    m: int,
    e: int,
    n: int,
    k: int,
    top_k: int,
    block_shape: Optional[List[int]],
) -> Optional[Dict[str, Any]]:
    """Look up the CK solution id for a W8A8 problem, or return None.

    Only W8A8 is handled. The first two entries of ``block_shape`` are passed
    on as the quantization block sizes (0 for both when it is None). Any
    exception is logged at debug level and turned into None.
    """
    try:
        if quant_type != MoeQuantType.W8A8:
            return None

        from .fused_moe_ck import get_moe_ck_solution_id, MoeQuantType as CkMoeQuantType
        from .jit.utils.chip_info import get_gfx

        arch = get_gfx()
        if block_shape is None:
            q_size_n = 0
            q_size_k = 0
        else:
            q_size_n = block_shape[0]
            q_size_k = block_shape[1]

        return {
            "solution_id": get_moe_ck_solution_id(
                arch,
                CkMoeQuantType.INT8_W8A8,
                m,
                n,
                k,
                e,
                top_k,
                q_size_n,
                q_size_k,
            )
        }
    except Exception as exc:
        logger.debug("CK config lookup failed for %s: %s", quant_type, exc)
        return None
def get_aiter_moe_config(
    M: int,  # Number of tokens (input sequence length)
    E: int,  # Number of experts
    N1: int,  # GEMM1 output dimension, typically equal to (moe_intermediate_size / TP * 2)
    N2: int,  # GEMM2 output dimension, typically equal to hidden_size
    K: int,  # GEMM1 input dimension, typically equal to hidden_size; for GEMM2, K typically equal to (moe_intermediate_size / TP)
    top_k: int,
    block_size: int,
    dtype: torch.dtype,
    quant_type: str,
) -> Tuple[bool, AiterMoeConfig]:
    """Get the best backend config for a MOE problem.

    Currently supported quant types:
    - ``MoeQuantType.W16A16`` (non-quantized)
    - ``MoeQuantType.W4A16``
    - ``MoeQuantType.W8A8``
    - ``MoeQuantType.W4A8``

    Backend priority (first backend that returns a config wins):
    - ``w16a16``: asm > triton
    - ``w4a16``: moe_c for ``torch.float16``; asm > triton for ``torch.bfloat16``
    - ``w8a8``: moe_c > triton when ``block_size == 0`` (channel-wise);
      asm otherwise (block-wise)
    - ``w4a8``: moe_c

    The CK backend is currently not part of any candidate list. ``N2`` is
    accepted but not used by the lookup.

    Returns:
        ``(True, config)`` when a backend was found, otherwise
        ``(False, AiterMoeConfig(quant_type=quant_type))``.

    Raises:
        ValueError: for an unsupported ``quant_type``, or for W4A16 with a
            ``dtype`` other than float16 / bfloat16.
    """
    # Integer division: the lookups take the intermediate dimension as an
    # integer size. True division produced a float here (e.g. 128.0).
    n = N1 // 2

    if quant_type == MoeQuantType.W4A16:
        if dtype == torch.float16:
            candidates = [
                (MoeSolutionType.MOE_C, lambda: _try_get_moe_c_config(quant_type, M, E, n, block_size)),
            ]
        elif dtype == torch.bfloat16:
            candidates = [
                (MoeSolutionType.ASM, lambda: _try_get_asm_config(quant_type, M, E, n, K, top_k, block_size)),
                (MoeSolutionType.TRITON, lambda: _try_get_triton_config(quant_type, M, E, n, block_size)),
            ]
        else:
            raise ValueError(f"Unsupported dtype: {dtype}")
    elif quant_type == MoeQuantType.W8A8:
        if block_size == 0:  # Channel wise choose MOE_C
            candidates = [
                (MoeSolutionType.MOE_C, lambda: _try_get_moe_c_config(quant_type, M, E, n, block_size)),
                (MoeSolutionType.TRITON, lambda: _try_get_triton_config(quant_type, M, E, n, block_size)),
            ]
        else:  # Block wise choose ASM
            candidates = [
                (MoeSolutionType.ASM, lambda: _try_get_asm_config(quant_type, M, E, n, K, top_k, block_size)),
            ]
    elif quant_type == MoeQuantType.W4A8:
        candidates = [
            (MoeSolutionType.MOE_C, lambda: _try_get_moe_c_config(quant_type, M, E, n, block_size)),
        ]
    elif quant_type == MoeQuantType.W16A16:
        candidates = [
            (MoeSolutionType.ASM, lambda: _try_get_asm_config(quant_type, M, E, n, K, top_k, None)),
            (MoeSolutionType.TRITON, lambda: _try_get_triton_config(quant_type, M, E, n, block_size)),
        ]
    else:
        raise ValueError(f"Unsupported quant_type: {quant_type}")

    # Candidates are evaluated lazily, in priority order.
    for solution_type, get_config in candidates:
        config = get_config()
        if config is not None:
            return True, AiterMoeConfig(
                quant_type=quant_type,
                solution_type=solution_type,
                config=config,
            )
    return False, AiterMoeConfig(quant_type=quant_type)
def aiter_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    moe_config: AiterMoeConfig,
    inplace: Optional[bool] = False,
    activation: str = "silu",
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    routed_scaling_factor: Optional[float] = 1.0,
) -> torch.Tensor:
    """Execute MOE using the backend and quant type described by *moe_config*.

    ``moe_config.solution_type`` selects the backend (moe_c, asm, triton or
    ck) and ``moe_config.quant_type`` sets the quantization flags passed to
    it. The backend module is imported only when it is selected.

    Raises:
        ValueError: if *moe_config* carries no solution/quant type, or names
            a solution type that is not one of the four backends.
    """
    if moe_config.solution_type is None or moe_config.quant_type is None:
        raise ValueError(
            "moe_config has no valid solution_type/quant_type. "
            "Call get_aiter_moe_config first and check the status."
        )

    use_int4_w4a16 = moe_config.quant_type == MoeQuantType.W4A16
    use_int8_w8a8 = moe_config.quant_type == MoeQuantType.W8A8
    use_int8_w4a8 = moe_config.quant_type == MoeQuantType.W4A8

    # Positional tensors and keyword arguments accepted by every backend.
    tensors = (hidden_states, w1, w2, topk_weights, topk_ids)
    shared_kwargs = dict(
        inplace=inplace,
        use_int8_w8a8=use_int8_w8a8,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        routed_scaling_factor=routed_scaling_factor,
    )

    if moe_config.solution_type == MoeSolutionType.MOE_C:
        from .fused_moe_c import moe_c_fused_experts

        return moe_c_fused_experts(
            *tensors,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w4a8=use_int8_w4a8,
            **shared_kwargs,
        )

    if moe_config.solution_type == MoeSolutionType.ASM:
        from .fused_moe_asm_wna16 import fused_experts_asm_impl

        cfg = moe_config.config
        # The ASM backend takes both stage ids joined into a single string.
        solution_id = f"{cfg['SOL_ID1']}+{cfg['SOL_ID2']}"
        return fused_experts_asm_impl(
            *tensors,
            dtype=hidden_states.dtype,
            use_int4_w4a16=use_int4_w4a16,
            solution_id=solution_id,
            **shared_kwargs,
        )

    if moe_config.solution_type == MoeSolutionType.TRITON:
        from .ops.triton.fused_moe import fused_experts_impl

        # W8A8 channel-wise (block_shape=None) requires per_channel_quant=True
        per_channel_quant = use_int8_w8a8 and block_shape is None
        return fused_experts_impl(
            *tensors,
            odtype=hidden_states.dtype,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            **shared_kwargs,
        )

    if moe_config.solution_type == MoeSolutionType.CK:
        from .fused_moe_ck import run_fused_experts_ck_impl

        solution_id = moe_config.config["solution_id"]
        return run_fused_experts_ck_impl(
            *tensors,
            odtype=hidden_states.dtype,
            solution_id=solution_id,
            **shared_kwargs,
        )

    raise ValueError(f"Unknown solution_type: {moe_config.solution_type}")
def get_aiter_moe_config_w4a16(
    M: int,
    E: int,
    N1: int,
    N2: int,
    K: int,
    top_k: int,
    block_size: int,
    dtype: torch.dtype,
) -> Tuple[bool, AiterMoeConfig]:
    """Backward-compatible wrapper for w4a16 config lookup.

    Forwards every argument to :func:`get_aiter_moe_config` with the quant
    type fixed to ``MoeQuantType.W4A16``.
    """
    return get_aiter_moe_config(
        M=M,
        E=E,
        N1=N1,
        N2=N2,
        K=K,
        top_k=top_k,
        block_size=block_size,
        dtype=dtype,
        quant_type=MoeQuantType.W4A16,
    )
def aiter_moe_w4a16(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    moe_config: AiterMoeConfig,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    activation: str = "silu",
) -> torch.Tensor:
    """Backward-compatible wrapper for w4a16 execution.

    Delegates to :func:`aiter_moe`. ``inplace`` and ``routed_scaling_factor``
    are not exposed here, so :func:`aiter_moe` uses its own defaults for them.
    """
    return aiter_moe(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        moe_config,
        activation=activation,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zp,
        w2_zp=w2_zp,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 86
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 121
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 98
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 183
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 146
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 160
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 160
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"2": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"3": {
"BLOCK_SIZE_M": 16,
"MODE": 42
},
"4": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"5": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"6": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"7": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"8": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"9": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"10": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"11": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"12": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"13": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"14": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"15": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"16": {
"BLOCK_SIZE_M": 16,
"MODE": 38
},
"32": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"64": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"128": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"256": {
"BLOCK_SIZE_M": 16,
"MODE": 46
},
"512": {
"BLOCK_SIZE_M": 16,
"MODE": 43
},
"1024": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"2048": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"4096": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"8192": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"16384": {
"BLOCK_SIZE_M": 32,
"MODE": 86
},
"32768": {
"BLOCK_SIZE_M": 32,
"MODE": 86
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment