Commit 194e19bd authored by baominghelly's avatar baominghelly
Browse files

issue 787 - Merge from main and resolve conflict

parents 147a4ac7 215d1932
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "topkrouter_metax.h"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::metax {
// Backend-private state attached to each Descriptor for the metax device.
struct Descriptor::Opaque {
// Shared internals obtained from the owning metax Handle at creation time.
std::shared_ptr<device::metax::Handle::Internal> internal;
};
// Releases the backend-private Opaque state owned by this descriptor.
Descriptor::~Descriptor() {
delete _opaque;
}
// Builds a topkrouter descriptor for the metax backend.
//
// Validates the input tensor layout via TopkrouterInfo and stores the
// device handle internals in the descriptor's Opaque state.
//
// Returns INFINI_STATUS_BAD_TENSOR_STRIDES when the innermost dimension
// of x is not contiguous (the kernel requires stride-1 rows), otherwise
// INFINI_STATUS_SUCCESS. correction_bias_desc is accepted for interface
// symmetry but not inspected here.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t correction_bias_desc) {
    auto info_result = TopkrouterInfo::create(x_desc);
    CHECK_RESULT(info_result);
    auto topk_info = info_result.take();

    // Kernel assumes row-contiguous input; reject anything else early.
    if (topk_info.x_strides[1] != 1) {
        return INFINI_STATUS_BAD_TENSOR_STRIDES;
    }

    auto metax_handle = reinterpret_cast<device::metax::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{metax_handle->internal()},
        std::move(topk_info),
        0, // presumably the required workspace size — ctor signature not visible here
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
namespace {
// Launches topkrouter_kernel with one block per row: grid = (N), block =
// (BLOCK_SIZE) threads, no dynamic shared memory, on the given hcStream_t.
// Dispatches on the element dtype of d_input; F32, F16, and BF16 are
// supported, anything else yields INFINI_STATUS_BAD_TENSOR_DTYPE.
// d_values_out / d_indices_out receive the per-row top-k values and indices.
template <int BLOCK_SIZE = 128>
infiniStatus_t launch_topkrouter(float *d_values_out, int *d_indices_out, const void *d_input, const float *d_correction_bias,
                                 const float routed_scaling_factor, const size_t N, const size_t width, const size_t topk, infiniDtype_t xtype,
                                 hcStream_t stream) {
    dim3 grid_dim(N);
    dim3 block_dim(BLOCK_SIZE);
    switch (xtype) {
    case INFINI_DTYPE_F32:
        topkrouter_kernel<float, BLOCK_SIZE><<<grid_dim, block_dim, 0, stream>>>(d_values_out, d_indices_out, (float *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        break;
    case INFINI_DTYPE_F16:
        topkrouter_kernel<half, BLOCK_SIZE><<<grid_dim, block_dim, 0, stream>>>(d_values_out, d_indices_out, (half *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        break;
    case INFINI_DTYPE_BF16:
        topkrouter_kernel<cuda_bfloat16, BLOCK_SIZE><<<grid_dim, block_dim, 0, stream>>>(d_values_out, d_indices_out, (cuda_bfloat16 *)d_input, d_correction_bias, routed_scaling_factor, N, width, topk);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace
// Runs the topkrouter operator asynchronously on the given stream.
//
// values  : [N, topk] float output buffer for the routed weights.
// indices : [N, topk] int output buffer for the selected expert indices.
// x       : [N, width] input logits of dtype _info.xtype (F32/F16/BF16).
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when the caller-provided
// workspace is too small, INFINI_STATUS_BAD_PARAM for unsupported widths,
// and otherwise the status of the kernel launch dispatch.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    float *values,
    int *indices,
    const void *x,
    const float *correction_bias,
    const float routed_scaling_factor,
    const size_t topk,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    const size_t N = _info.N;
    const size_t width = _info.width;
    auto cuda_stream = reinterpret_cast<hcStream_t>(stream);

    // Only width == 256 is supported for now (matches the commented
    // expert-routing constants: 256 routed experts).
    if (width != 256) {
        return INFINI_STATUS_BAD_PARAM;
    }

    // Bug fix: propagate the dispatch status. Previously the return value of
    // launch_topkrouter was discarded, so an unsupported dtype
    // (INFINI_STATUS_BAD_TENSOR_DTYPE) was silently reported as SUCCESS.
    return launch_topkrouter<256>(values, indices, x, correction_bias,
                                  routed_scaling_factor, N, width, topk,
                                  _info.xtype, cuda_stream);
}
} // namespace op::topkrouter::metax
...@@ -2,9 +2,17 @@ ...@@ -2,9 +2,17 @@
#include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../cuda/kernel.cuh"
#include "topkrouter_nvidia.cuh" #include "topkrouter_nvidia.cuh"
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh> #include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include "../cuda/kernel.cuh"
namespace op::topkrouter::nvidia { namespace op::topkrouter::nvidia {
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/topkrouter_nvidia.cuh" #include "nvidia/topkrouter_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API
#include "metax/topkrouter_metax.h"
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/topkrouter_kunlun.h" #include "kunlun/topkrouter_kunlun.h"
#endif #endif
...@@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i ...@@ -30,6 +33,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia); CREATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
...@@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript ...@@ -56,6 +62,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia); GET(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun); GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
...@@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void ...@@ -85,6 +94,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia); CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
...@@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip ...@@ -111,6 +123,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#ifdef ENABLE_QY_API #ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia); DESTROY(INFINI_DEVICE_QY, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun); DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU; thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU;
thread_local int CURRENT_DEVICE_ID = 0; thread_local int CURRENT_DEVICE_ID = 0;
thread_local infiniDevice_t PREVIOUS_NON_CPU_DEVICE_TYPE = INFINI_DEVICE_TYPE_COUNT;
thread_local int PREVIOUS_NON_CPU_DEVICE_ID = 0;
__C infiniStatus_t infinirtInit() { __C infiniStatus_t infinirtInit() {
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
...@@ -96,6 +98,16 @@ __C infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count) { ...@@ -96,6 +98,16 @@ __C infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count) {
} }
__C infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id) { __C infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id) {
bool skip_set = CURRENT_DEVICE_TYPE == INFINI_DEVICE_CPU && device == PREVIOUS_NON_CPU_DEVICE_TYPE && device_id == PREVIOUS_NON_CPU_DEVICE_ID;
if (CURRENT_DEVICE_TYPE != INFINI_DEVICE_CPU) {
PREVIOUS_NON_CPU_DEVICE_TYPE = CURRENT_DEVICE_TYPE;
PREVIOUS_NON_CPU_DEVICE_ID = CURRENT_DEVICE_ID;
}
if (skip_set) {
CURRENT_DEVICE_TYPE = device;
CURRENT_DEVICE_ID = device_id;
return INFINI_STATUS_SUCCESS;
}
INFINIRT_CALL_DEVICE_API_AND(device, setDevice, (device_id), INFINIRT_CALL_DEVICE_API_AND(device, setDevice, (device_id),
{ CURRENT_DEVICE_TYPE = device; { CURRENT_DEVICE_TYPE = device;
CURRENT_DEVICE_ID = device_id; }); CURRENT_DEVICE_ID = device_id; });
......
#include "infinirt_moore.h" #include "infinirt_moore.h"
#include "../../utils.h" #include "../../utils.h"
#include <chrono>
#include <musa_runtime.h> #include <musa_runtime.h>
#include <musa_runtime_api.h> #include <musa_runtime_api.h>
...@@ -83,23 +82,7 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) { ...@@ -83,23 +82,7 @@ infiniStatus_t eventDestroy(infinirtEvent_t event) {
} }
infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) { infiniStatus_t eventElapsedTime(float *ms_ptr, infinirtEvent_t start, infinirtEvent_t end) {
// MUSA may not have direct musaEventElapsedTime API CHECK_MUSART(musaEventElapsedTime(ms_ptr, (musaEvent_t)start, (musaEvent_t)end));
// Use a fallback method: synchronize events and measure CPU time difference
// Note: This includes synchronization overhead, so timing may not be as accurate
// as native GPU event timing, but it allows benchmarking to work
// Synchronize start event and record CPU time
CHECK_MUSART(musaEventSynchronize((musaEvent_t)start));
auto start_cpu_time = std::chrono::steady_clock::now();
// Synchronize end event and record CPU time
CHECK_MUSART(musaEventSynchronize((musaEvent_t)end));
auto end_cpu_time = std::chrono::steady_clock::now();
// Calculate elapsed time in milliseconds
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_cpu_time - start_cpu_time);
*ms_ptr = static_cast<float>(duration.count()) / 1000.0f;
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -22,7 +22,6 @@ from .utils.compare_utils import ( ...@@ -22,7 +22,6 @@ from .utils.compare_utils import (
) )
from .utils.json_utils import save_json_report from .utils.json_utils import save_json_report
from .utils.tensor_utils import ( from .utils.tensor_utils import (
infinicore_tensor_from_torch, infinicore_tensor_from_torch,
convert_infinicore_to_torch, convert_infinicore_to_torch,
rearrange_tensor, rearrange_tensor,
...@@ -37,7 +36,6 @@ __all__ = [ ...@@ -37,7 +36,6 @@ __all__ = [
# Core types and classes # Core types and classes
"BaseOperatorTest", "BaseOperatorTest",
"CaseResult", "CaseResult",
"ConsolePrinter",
"GenericTestRunner", "GenericTestRunner",
"InfiniDeviceEnum", "InfiniDeviceEnum",
"InfiniDeviceNames", "InfiniDeviceNames",
...@@ -47,7 +45,7 @@ __all__ = [ ...@@ -47,7 +45,7 @@ __all__ = [
"TestCase", "TestCase",
"TestConfig", "TestConfig",
"TestDriver", "TestDriver",
"TestReporter", "TestSummary",
"TestRunner", "TestRunner",
"TestTiming", "TestTiming",
# Core functions # Core functions
......
...@@ -15,6 +15,8 @@ def synchronize_device(torch_device): ...@@ -15,6 +15,8 @@ def synchronize_device(torch_device):
torch.npu.synchronize() torch.npu.synchronize()
elif torch_device == "mlu": elif torch_device == "mlu":
torch.mlu.synchronize() torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
# ================================================================= # =================================================================
......
...@@ -71,17 +71,32 @@ class TestTensor(CTensor): ...@@ -71,17 +71,32 @@ class TestTensor(CTensor):
torch_shape.append(shape[i]) torch_shape.append(shape[i])
if mode == "random": if mode == "random":
# For integer types, use randint instead of rand # For integer types, use randint instead of rand
if dt in [InfiniDtype.I8, InfiniDtype.I16, InfiniDtype.I32, InfiniDtype.I64, if dt in [
InfiniDtype.U8, InfiniDtype.U16, InfiniDtype.U32, InfiniDtype.U64, InfiniDtype.I8,
InfiniDtype.BYTE, InfiniDtype.BOOL]: InfiniDtype.I16,
InfiniDtype.I32,
InfiniDtype.I64,
InfiniDtype.U8,
InfiniDtype.U16,
InfiniDtype.U32,
InfiniDtype.U64,
InfiniDtype.BYTE,
InfiniDtype.BOOL,
]:
randint_low = -2000000000 if randint_low is None else randint_low randint_low = -2000000000 if randint_low is None else randint_low
randint_high = 2000000000 if randint_high is None else randint_high randint_high = 2000000000 if randint_high is None else randint_high
self._torch_tensor = torch.randint( self._torch_tensor = torch.randint(
randint_low, randint_high, torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device] randint_low,
randint_high,
torch_shape,
dtype=to_torch_dtype(dt),
device=torch_device_map[device],
) )
else: else:
self._torch_tensor = torch.rand( self._torch_tensor = torch.rand(
torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device] torch_shape,
dtype=to_torch_dtype(dt),
device=torch_device_map[device],
) )
elif mode == "zeros": elif mode == "zeros":
self._torch_tensor = torch.zeros( self._torch_tensor = torch.zeros(
...@@ -431,6 +446,8 @@ def synchronize_device(torch_device): ...@@ -431,6 +446,8 @@ def synchronize_device(torch_device):
torch.npu.synchronize() torch.npu.synchronize()
elif torch_device == "mlu": elif torch_device == "mlu":
torch.mlu.synchronize() torch.mlu.synchronize()
elif torch_device == "musa":
torch.musa.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
...@@ -463,7 +480,14 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): ...@@ -463,7 +480,14 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
def filter_tensor_dtypes_by_device(device, tensor_dtypes): def filter_tensor_dtypes_by_device(device, tensor_dtypes):
if device in (InfiniDeviceEnum.CPU, InfiniDeviceEnum.NVIDIA, InfiniDeviceEnum.METAX, InfiniDeviceEnum.ASCEND, InfiniDeviceEnum.ILUVATAR, InfiniDeviceEnum.CAMBRICON): if device in (
InfiniDeviceEnum.CPU,
InfiniDeviceEnum.NVIDIA,
InfiniDeviceEnum.METAX,
InfiniDeviceEnum.ASCEND,
InfiniDeviceEnum.ILUVATAR,
InfiniDeviceEnum.CAMBRICON,
):
return tensor_dtypes return tensor_dtypes
else: else:
# 过滤掉 torch.bfloat16 # 过滤掉 torch.bfloat16
......
...@@ -19,7 +19,7 @@ from libinfiniop import ( ...@@ -19,7 +19,7 @@ from libinfiniop import (
InfiniDtypeNames, InfiniDtypeNames,
InfiniDeviceNames, InfiniDeviceNames,
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
torch_device_map torch_device_map,
) )
# ============================================================================== # ==============================================================================
...@@ -29,12 +29,14 @@ from libinfiniop import ( ...@@ -29,12 +29,14 @@ from libinfiniop import (
_TEST_CASES_ = [ _TEST_CASES_ = [
# x_shape, x_stride, topk, routed_scaling_factor # x_shape, x_stride, topk, routed_scaling_factor
((1, 256), None, 8, 2.5), ((1, 256), None, 8, 2.5),
((2, 256), None, 8, 1.0),
] ]
# w (weight) types # w (weight) types
# Note: 'None' means the same as input dtype # Note: 'None' means the same as input dtype
# _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] # _X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16]
_X_DTYPES = [] # CPU CI _X_DTYPES = [] # CPU CI
# x types used for testing # x types used for testing
_VALUE_DTYPES = [InfiniDtype.F32] _VALUE_DTYPES = [InfiniDtype.F32]
...@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000 ...@@ -57,17 +59,34 @@ NUM_ITERATIONS = 1000
def tensorInfo(data): def tensorInfo(data):
print("data: ", data.is_contiguous(), data.device, data.dtype, data.shape, data.stride(), data.data_ptr(), hex(data.data_ptr())) print(
"data: ",
data.is_contiguous(),
data.device,
data.dtype,
data.shape,
data.stride(),
data.data_ptr(),
hex(data.data_ptr()),
)
class DeepseekV3TopkRouter(nn.Module): class DeepseekV3TopkRouter(nn.Module):
def __init__(self, correction_bias, routed_scaling_factor: float = 2.5, topk: int = 8, config=None): def __init__(
self,
correction_bias,
routed_scaling_factor: float = 2.5,
topk: int = 8,
config=None,
):
super().__init__() super().__init__()
self.config = config self.config = config
self.top_k = topk # config.num_experts_per_tok 8 self.top_k = topk # config.num_experts_per_tok 8
assert topk == 8 assert topk == 8
self.n_routed_experts = 256 # config.n_routed_experts self.n_routed_experts = 256 # config.n_routed_experts
self.routed_scaling_factor = routed_scaling_factor # config.routed_scaling_factor 2.5 self.routed_scaling_factor = (
routed_scaling_factor # config.routed_scaling_factor 2.5
)
self.n_group = 8 # config.n_group self.n_group = 8 # config.n_group
self.topk_group = 4 # config.topk_group self.topk_group = 4 # config.topk_group
self.norm_topk_prob = True # config.norm_topk_prob self.norm_topk_prob = True # config.norm_topk_prob
...@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module): ...@@ -81,14 +100,20 @@ class DeepseekV3TopkRouter(nn.Module):
@torch.no_grad() @torch.no_grad()
def get_topk_indices(self, scores): def get_topk_indices(self, scores):
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256]) scores_for_choice = scores.view(
-1, self.n_routed_experts
) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256])
group_scores = ( group_scores = (
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) scores_for_choice.view(
-1, self.n_group, self.n_routed_experts // self.n_group
)
.topk(2, dim=-1)[0] .topk(2, dim=-1)[0]
.sum(dim=-1) .sum(dim=-1)
) )
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1] # Size([1, 4]) group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[
1
] # Size([1, 4])
group_mask = torch.zeros_like(group_scores) # Size([1, 8]) group_mask = torch.zeros_like(group_scores) # Size([1, 8])
group_mask.scatter_(1, group_idx, 1) # Size([1, 8]) group_mask.scatter_(1, group_idx, 1) # Size([1, 8])
...@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module): ...@@ -98,8 +123,12 @@ class DeepseekV3TopkRouter(nn.Module):
.reshape(-1, self.n_routed_experts) .reshape(-1, self.n_routed_experts)
) )
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # Size([1, 256]) scores_for_choice = scores_for_choice.masked_fill(
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[1] # Size([1, 8]) ~score_mask.bool(), 0.0
) # Size([1, 256])
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[
1
] # Size([1, 8])
return topk_indices return topk_indices
...@@ -122,21 +151,23 @@ class DeepseekV3TopkRouter(nn.Module): ...@@ -122,21 +151,23 @@ class DeepseekV3TopkRouter(nn.Module):
def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk): def torch_topkrouter(router_logits, correction_bias, routed_scaling_factor, topk):
lable_indices, lable_values = DeepseekV3TopkRouter(correction_bias, routed_scaling_factor, topk)(router_logits) lable_indices, lable_values = DeepseekV3TopkRouter(
correction_bias, routed_scaling_factor, topk
)(router_logits)
lable_indices = lable_indices.to(torch.int32) lable_indices = lable_indices.to(torch.int32)
return lable_values, lable_indices return lable_values, lable_indices
def test( def test(
handle, handle,
device, device,
x_shape, x_shape,
x_stride, x_stride,
topk, topk,
routed_scaling_factor, routed_scaling_factor,
x_dtype=InfiniDtype.F32, x_dtype=InfiniDtype.F32,
dtype=InfiniDtype.F16, dtype=InfiniDtype.F16,
sync=None, sync=None,
): ):
print( print(
f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape} " f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape} "
...@@ -146,8 +177,12 @@ def test( ...@@ -146,8 +177,12 @@ def test(
data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape) data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape)
N, width = x_shape N, width = x_shape
x = TestTensor(x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random") x = TestTensor(
correction_bias = TestTensor([x_shape[1]], [1], InfiniDtype.F32, device, mode="random") x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random"
)
correction_bias = TestTensor(
[x_shape[1]], [1], InfiniDtype.F32, device, mode="random"
)
if sync is not None: if sync is not None:
sync() sync()
...@@ -155,10 +190,7 @@ def test( ...@@ -155,10 +190,7 @@ def test(
descriptor = infiniopOperatorDescriptor_t() descriptor = infiniopOperatorDescriptor_t()
check_error( check_error(
LIBINFINIOP.infiniopCreateTopkrouterDescriptor( LIBINFINIOP.infiniopCreateTopkrouterDescriptor(
handle, handle, ctypes.byref(descriptor), x.descriptor, correction_bias.descriptor
ctypes.byref(descriptor),
x.descriptor,
correction_bias.descriptor
) )
) )
...@@ -174,8 +206,12 @@ def test( ...@@ -174,8 +206,12 @@ def test(
) )
workspace = TestWorkspace(workspace_size.value, x.device) workspace = TestWorkspace(workspace_size.value, x.device)
values = torch.zeros((N, topk), dtype=torch.float32, device=torch_device_map[x.device]) values = torch.zeros(
indices = torch.zeros((N, topk), dtype=torch.int32, device=torch_device_map[x.device]) (N, topk), dtype=torch.float32, device=torch_device_map[x.device]
)
indices = torch.zeros(
(N, topk), dtype=torch.int32, device=torch_device_map[x.device]
)
def lib_topkrouter(): def lib_topkrouter():
check_error( check_error(
...@@ -195,8 +231,9 @@ def test( ...@@ -195,8 +231,9 @@ def test(
lib_topkrouter() lib_topkrouter()
lable_values, lable_indices = torch_topkrouter(
lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk) x.actual_tensor(), correction_bias.actual_tensor(), routed_scaling_factor, topk
)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(lable_values, values, atol=atol, rtol=rtol) debug(lable_values, values, atol=atol, rtol=rtol)
......
...@@ -66,6 +66,13 @@ if has_config("cudnn") then ...@@ -66,6 +66,13 @@ if has_config("cudnn") then
add_defines("ENABLE_CUDNN_API") add_defines("ENABLE_CUDNN_API")
end end
option("cuda_arch")
set_showmenu(true)
set_description("Set CUDA GPU architecture (e.g. sm_90)")
set_values("sm_50", "sm_60", "sm_70", "sm_75", "sm_80", "sm_86", "sm_89", "sm_90", "sm_90a")
set_category("option")
option_end()
-- 寒武纪 -- 寒武纪
option("cambricon-mlu") option("cambricon-mlu")
set_default(false) set_default(false)
......
...@@ -3,6 +3,14 @@ if CUDNN_ROOT ~= nil then ...@@ -3,6 +3,14 @@ if CUDNN_ROOT ~= nil then
add_includedirs(CUDNN_ROOT .. "/include") add_includedirs(CUDNN_ROOT .. "/include")
end end
local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")
local CUTE_ROOT = os.getenv("CUTE_ROOT") or os.getenv("CUTE_HOME") or os.getenv("CUTE_PATH")
if CUTLASS_ROOT ~= nil then
add_includedirs(CUTLASS_ROOT)
add_includedirs(CUTE_ROOT)
end
target("infiniop-nvidia") target("infiniop-nvidia")
set_kind("static") set_kind("static")
add_deps("infini-utils") add_deps("infini-utils")
...@@ -28,6 +36,11 @@ target("infiniop-nvidia") ...@@ -28,6 +36,11 @@ target("infiniop-nvidia")
target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs") target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
target:add("links", "cuda") target:add("links", "cuda")
local cuda_arch = get_config("cuda_arch")
if cuda_arch ~= nil then
target:add("cu-cxxflags", "-arch=", cuda_arch)
end
end end
end) end)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment