Commit 76e6817b authored by Your Name's avatar Your Name
Browse files

[Adaptation] Adapt ollama 0.11.0 for DCU

parent d5520684
...@@ -43,7 +43,7 @@ const ( ...@@ -43,7 +43,7 @@ const (
var ( var (
// Used to validate if the given ROCm lib is usable // Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"} RocmStandardLocations = []string{"/opt/dtk/lib", "/usr/lib64"}
) )
// Gather GPU information from the amdgpu driver if any supported GPUs are detected // Gather GPU information from the amdgpu driver if any supported GPUs are detected
...@@ -55,11 +55,11 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { ...@@ -55,11 +55,11 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
} }
// Opportunistic logging of driver version to aid in troubleshooting // Opportunistic logging of driver version to aid in troubleshooting
driverMajor, driverMinor, err := AMDDriverVersion() //driverMajor, driverMinor, err := AMDDriverVersion()
if err != nil { //if err != nil {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err) // slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err)
} //}
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
var visibleDevices []string var visibleDevices []string
...@@ -283,8 +283,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { ...@@ -283,8 +283,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
Name: name, Name: name,
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
MinimumMemory: rocmMinimumMemory, MinimumMemory: rocmMinimumMemory,
DriverMajor: driverMajor, //DriverMajor: driverMajor,
DriverMinor: driverMinor, //DriverMinor: driverMinor,
}, },
usedFilepath: usedFile, usedFilepath: usedFile,
index: gpuID, index: gpuID,
...@@ -413,15 +413,15 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { ...@@ -413,15 +413,15 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
// Quick check for AMD driver so we can skip amdgpu discovery if not present // Quick check for AMD driver so we can skip amdgpu discovery if not present
func AMDDetected() bool { func AMDDetected() bool {
// Some driver versions (older?) don't have a version file, so just lookup the parent dir // Some driver versions (older?) don't have a version file, so just lookup the parent dir
sysfsDir := filepath.Dir(DriverVersionFile) //sysfsDir := filepath.Dir(DriverVersionFile)
_, err := os.Stat(sysfsDir) //_, err := os.Stat(sysfsDir)
if errors.Is(err, os.ErrNotExist) { //if errors.Is(err, os.ErrNotExist) {
slog.Debug("amdgpu driver not detected " + sysfsDir) // slog.Debug("amdgpu driver not detected " + sysfsDir)
return false // return false
} else if err != nil { //} else if err != nil {
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err) // slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
return false // return false
} //}
return true return true
} }
......
...@@ -263,7 +263,8 @@ static bool cp_async_available(const int cc) { ...@@ -263,7 +263,8 @@ static bool cp_async_available(const int cc) {
static constexpr __device__ int ggml_cuda_get_physical_warp_size() { static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE; //return __AMDGCN_WAVEFRONT_SIZE;
return 32;
#else #else
return 32; return 32;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
......
...@@ -249,7 +249,8 @@ static void mul_mat_vec_q_switch_ncols_dst( ...@@ -249,7 +249,8 @@ static void mul_mat_vec_q_switch_ncols_dst(
const int sample_ratio = nsamples_dst / nsamples_x; const int sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device(); const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size; //const int warp_size = ggml_cuda_info().devices[device].warp_size;
const int warp_size = 32;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
GGML_ASSERT(!ids || ncols_dst == 1); GGML_ASSERT(!ids || ncols_dst == 1);
...@@ -257,7 +258,11 @@ static void mul_mat_vec_q_switch_ncols_dst( ...@@ -257,7 +258,11 @@ static void mul_mat_vec_q_switch_ncols_dst(
case 1: case 1:
{ {
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); //std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
const int64_t nblocks = nrows_x;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, 4, 1);
std::pair<dim3, dim3> dims(block_nums, block_dims);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
......
...@@ -46,7 +46,7 @@ ...@@ -46,7 +46,7 @@
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t #define cudaDeviceProp hipDeviceProp_t_v2
#define cudaDeviceSynchronize hipDeviceSynchronize #define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t #define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
...@@ -61,7 +61,7 @@ ...@@ -61,7 +61,7 @@
#define cudaFreeHost hipHostFree #define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice #define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount #define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties #define cudaGetDeviceProperties hipGetDeviceProperties_v2
#define cudaGetErrorString hipGetErrorString #define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError #define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister #define cudaHostRegister hipHostRegister
......
...@@ -46,9 +46,9 @@ if (GGML_HIP_ROCWMMA_FATTN) ...@@ -46,9 +46,9 @@ if (GGML_HIP_ROCWMMA_FATTN)
endif() endif()
endif() endif()
if (${hip_VERSION} VERSION_LESS 5.5) #if (${hip_VERSION} VERSION_LESS 5.5)
message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") # message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
endif() #endif()
message(STATUS "HIP and hipBLAS found") message(STATUS "HIP and hipBLAS found")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment