Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
12494cf5
Commit
12494cf5
authored
Dec 23, 2024
by
xuxzh1
🎱
Browse files
adapt v3.0.0
parent
8f326c97
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
419 additions
and
14 deletions
+419
-14
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_5.cuh
...r/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_5.cuh
+209
-0
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_6.cuh
...r/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_6.cuh
+44
-0
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_8.cuh
...r/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_8.cuh
+40
-0
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_util.cuh
...xllamav2_kernels/exllamav2_kernels/hip/quant/qdq_util.cuh
+55
-0
server/exllamav2_kernels/exllamav2_kernels/hip/util.cuh
server/exllamav2_kernels/exllamav2_kernels/hip/util.cuh
+56
-0
server/text_generation_server/layers/layernorm.py
server/text_generation_server/layers/layernorm.py
+2
-1
server/text_generation_server/layers/linear.py
server/text_generation_server/layers/linear.py
+13
-13
No files found.
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_5.cuh
0 → 100644
View file @
12494cf5
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
// Permutes 32 packed 5-bit values in place.
//
// Reads the five words q[0 * stride] .. q[4 * stride], which hold the values
// back to back (value 0 in the lowest bits of the first word), and writes them
// back in the interleaved layout shown in the comments inside the body.
//
//   q       first of the five words to rewrite
//   stride  distance, in uint32_t elements, between consecutive words
__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
    // Fetch all five words before anything is written back.
    uint32_t src[5];
    #pragma unroll
    for (int w = 0; w < 5; ++w) src[w] = q[w * stride];

    // src[0]: 66555554 44443333 32222211 11100000
    // src[1]: ccccbbbb baaaaa99 99988888 77777666
    // src[2]: jiiiiihh hhhggggg fffffeee eedddddc
    // src[3]: pppooooo nnnnnmmm mmlllllk kkkkjjjj
    // src[4]: vvvvvuuu uuttttts ssssrrrr rqqqqqpp

    // The last two values (u, v) live in the top ten bits of the final word.
    const uint32_t tail = src[4] >> 22;

    // Realign every word so it starts on a value boundary. Work from the last
    // word down so each step still sees its lower neighbour unmodified.
    src[4] = (src[4] << 8) | (src[3] >> 24);
    src[3] = (src[3] << 6) | (src[2] >> 26);
    src[2] = (src[2] << 4) | (src[1] >> 28);
    src[1] = (src[1] << 2) | (src[0] >> 30);

    // src[0]:   555554 44443333 32222211 11100000
    // src[1]:   bbbbba aaaa9999 98888877 77766666
    // src[2]:   hhhhhg ggggffff feeeeedd dddccccc
    // src[3]:   nnnnnm mmmmllll lkkkkkjj jjjiiiii
    // src[4]:   ttttts ssssrrrr rqqqqqpp pppooooo
    // tail:                          vv vvvuuuuu
    // (only the low 30 bits of each realigned word are consumed below)

    #pragma unroll
    for (int w = 0; w < 5; ++w)
    {
        uint32_t rest = src[w];
        uint32_t out = 0;

        // Even-numbered values go to the low 16 bits, odd-numbered values to
        // the high 16 bits, three of each per word.
        #pragma unroll
        for (int pair = 0; pair < 3; ++pair)
        {
            const uint32_t even = rest & 0x1f;
            const uint32_t odd = (rest >> 5) & 0x1f;
            rest >>= 10;
            out |= even << (pair * 5);
            out |= odd << (pair * 5 + 16);
        }

        // Bit w of u fills bit 15, bit w of v fills bit 31.
        out |= ((tail >> w) & 1u) << 15;
        out |= ((tail >> (w + 5)) & 1u) << 31;

        // word 0: v5555533 33311111 u4444422 22200000  (u, v lsb)
        // word 1: vbbbbb99 99977777 uaaaaa88 88866666
        // word 2: vhhhhhff fffddddd ugggggee eeeccccc
        // word 3: vnnnnnll llljjjjj ummmmmkk kkkiiiii
        // word 4: vtttttrr rrrppppp usssssqq qqqooooo
        q[w * stride] = out;
    }
}
// Expands five words in the permuted 5-bit layout (see shuffle_5bit_32 in this
// branch) into 32 half values, written to dq as 16 half2 pairs. Every output
// is the stored 5-bit value minus 16.
//
//   q_0..q_4  the five permuted words
//   dq        receives the 16 result pairs
//   stride    accepted for interface compatibility; not used here
__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    // 0x6400 is the fp16 bit pattern of 1024.0. OR-ing a small integer n into
    // its low mantissa bits yields the half value 1024 + n without a convert.
    const uint32_t bias_bits = 0x64006400;

    // Constants that remove the 1024 bias and subtract the zero point 16,
    // for fields at bit 0 (scale 1) and at bit 5 (scale 32).
    const half inv32 = __float2half_rn(1.0f / 32.0f);
    const half off1 = __float2half_rn(-1024.0f - 16.0f);
    const half off32 = __float2half_rn(-1024.0f / 32.0f - 16.0f);
    const half2 y32 = __halves2half2(inv32, inv32);
    const half2 z1 = __halves2half2(off1, off1);
    const half2 z32 = __halves2half2(off32, off32);

    const uint32_t packed[5] = { q_0, q_1, q_2, q_3, q_4 };

    // Collects the final pair, whose bits are spread over bit 15 / bit 31 of
    // each word.
    uint32_t last_pair = 0;

    #pragma unroll
    for (int w = 0; w < 5; ++w)
    {
        uint32_t bits = packed[w];

        const half2_uint32 lo((bits & 0x001f001f) | bias_bits);    // pair + 1024
        const half2_uint32 mid((bits & 0x03e003e0) | bias_bits);   // pair * 32 + 1024
        bits >>= 10;
        const half2_uint32 hi((bits & 0x001f001f) | bias_bits);    // pair + 1024

        // After the shift the spare bits sit at positions 5 and 21; move them
        // to positions w and w + 16 of the final pair.
        last_pair |= (bits >> (5 - w)) & (0x00010001u << w);

        dq[3 * w + 0] = __hadd2(lo.as_half2, z1);
        dq[3 * w + 1] = __hfma2(mid.as_half2, y32, z32);
        dq[3 * w + 2] = __hadd2(hi.as_half2, z1);
    }

    const half2_uint32 tail(last_pair | bias_bits);                // pair + 1024
    dq[15] = __hadd2(tail.as_half2, z1);
}
#else
// Fallback used when QMODE_5BIT is not 1: the packed words are left exactly as
// they are, so this function intentionally does nothing.
__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
    // Both parameters are deliberately unused.
    (void) q;
    (void) stride;
}
// Fallback 5-bit dequantizer (QMODE_5BIT != 1). Treats q_0..q_4 as one 160-bit
// little-endian stream of 32 consecutive 5-bit values, converts each value
// with dq_ns(value, 16), and stores them pairwise in dq.
//
//   q_0..q_4  the five packed words, lowest bits first
//   dq        receives 16 pairs: dq[i] = (value 2i, value 2i + 1)
//   stride    accepted for interface compatibility; not used here
__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t words[5] = { q_0, q_1, q_2, q_3, q_4 };
    half vals[32];

    #pragma unroll
    for (int k = 0; k < 32; ++k)
    {
        const int bit = k * 5;      // position of value k in the stream
        const int w = bit >> 5;     // word holding its lowest bit
        const int ofs = bit & 31;   // offset of that bit inside the word

        int raw;
        if (ofs + 5 <= 32)
            raw = exb(words[w], ofs, 0x1f);                 // entirely inside one word
        else
            raw = exb(words[w + 1], words[w], ofs, 0x1f);   // straddles two words

        vals[k] = dq_ns(raw, 16);
    }

    #pragma unroll
    for (int p = 0; p < 16; ++p)
        dq[p] = __halves2half2(vals[p * 2], vals[p * 2 + 1]);
}
#endif
#endif
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_6.cuh
0 → 100644
View file @
12494cf5
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
// Fallback used when QMODE_6BIT is not 1: the packed words are left exactly as
// they are, so this function intentionally does nothing.
__forceinline__ __device__ void shuffle_6bit_16
(
    uint32_t* q,
    int stride
)
{
    // Both parameters are deliberately unused.
    (void) q;
    (void) stride;
}
// Fallback 6-bit dequantizer (QMODE_6BIT != 1). Treats q_0..q_2 as one 96-bit
// little-endian stream of 16 consecutive 6-bit values, converts each value
// with dq_ns(value, 32), and stores them pairwise in dq.
//
//   q_0..q_2  the three packed words, lowest bits first
//   dq        receives 8 pairs: dq[i] = (value 2i, value 2i + 1)
//   stride    accepted for interface compatibility; not used here
__forceinline__ __device__ void dequant_6bit_16
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    half2 (&dq)[8],
    int stride
)
{
    const uint32_t words[3] = { q_0, q_1, q_2 };
    half vals[16];

    #pragma unroll
    for (int k = 0; k < 16; ++k)
    {
        const int bit = k * 6;      // position of value k in the stream
        const int w = bit >> 5;     // word holding its lowest bit
        const int ofs = bit & 31;   // offset of that bit inside the word

        int raw;
        if (ofs + 6 <= 32)
            raw = exb(words[w], ofs, 0x3f);                 // entirely inside one word
        else
            raw = exb(words[w + 1], words[w], ofs, 0x3f);   // straddles two words

        vals[k] = dq_ns(raw, 32);
    }

    #pragma unroll
    for (int p = 0; p < 8; ++p)
        dq[p] = __halves2half2(vals[p * 2], vals[p * 2 + 1]);
}
#endif
#endif
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_8.cuh
0 → 100644
View file @
12494cf5
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
// Fallback used when QMODE_8BIT is not 1: the packed word is left exactly as
// it is, so this function intentionally does nothing.
__forceinline__ __device__ void shuffle_8bit_4
(
    uint32_t* q,
    int stride
)
{
    // Both parameters are deliberately unused.
    (void) q;
    (void) stride;
}
// Fallback 8-bit dequantizer (QMODE_8BIT != 1). Splits q_0 and q_1 into their
// four bytes each (lowest byte first), converts every byte with
// dq_ns(byte, 128), and stores the eight results pairwise in dq.
//
//   q_0, q_1  the two packed words
//   dq        receives 4 pairs: dq[0..1] from q_0, dq[2..3] from q_1
//   stride    accepted for interface compatibility; not used here
__forceinline__ __device__ void dequant_8bit_8
(
    const uint32_t q_0,
    const uint32_t q_1,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t words[2] = { q_0, q_1 };

    #pragma unroll
    for (int w = 0; w < 2; ++w)
    {
        #pragma unroll
        for (int p = 0; p < 2; ++p)
        {
            // Each pair covers two adjacent bytes of the current word.
            const int shift = p * 16;
            const half first = dq_ns(exb(words[w], shift, 0xff), 128);
            const half second = dq_ns(exb(words[w], shift + 8, 0xff), 128);
            dq[w * 2 + p] = __halves2half2(first, second);
        }
    }
}
#endif
#endif
server/exllamav2_kernels/exllamav2_kernels/hip/quant/qdq_util.cuh
0 → 100644
View file @
12494cf5
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
// Overlays a 32-bit integer and a half2 on the same storage, so a bit pattern
// built with integer operations can be read back through as_half2 (and the
// other way round through as_uint32).
// NOTE(review): reading the member that was not the last one written relies on
// the compiler supporting union type punning.
union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;

    // All bits zero.
    __device__ half2_uint32() : as_uint32(0) {}

    // Start from a raw 32-bit pattern.
    __device__ half2_uint32(uint32_t bits) : as_uint32(bits) {}

    // Start from an existing half2 value.
    __device__ half2_uint32(half2 pair) : as_half2(pair) {}
};
// Overlays a 16-bit integer and a half on the same storage, so a bit pattern
// built with integer operations can be read back through as_half (and the
// other way round through as_uint16).
// NOTE(review): reading the member that was not the last one written relies on
// the compiler supporting union type punning.
union half_uint16
{
    uint16_t as_uint16;
    half as_half;

    // All bits zero.
    __device__ half_uint16() : as_uint16(0) {}

    // Start from a raw 16-bit pattern.
    __device__ half_uint16(uint16_t bits) : as_uint16(bits) {}

    // Start from an existing half value.
    __device__ half_uint16(half value) : as_half(value) {}
};
// Returns (qs + 1)^2 * max_scale as a half. The square is formed in integer
// arithmetic and converted to half before the multiply.
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale
(
    const int qs,
    const half max_scale
)
{
    const int step = qs + 1;
    const half squared = __int2half_rn(step * step);
    return __hmul(squared, max_scale);
}
// Dequantizes one value: (q - qzero) * scale. The subtraction is done on
// integers; the difference is then converted to half and multiplied.
__forceinline__ __device__ half dq
(
    const int q,
    const int qzero,
    const half scale
)
{
    const half centered = __int2half_rn(q - qzero);
    return __hmul(centered, scale);
}
// Dequantizes one value without a scale: returns q - qzero converted to half.
__forceinline__ __device__ half dq_ns
(
    const int q,
    const int qzero
)
{
    // One integer subtraction and a single conversion. The two-conversion
    // variant from the original code is kept here for reference:
    //   __hsub(__int2half_rn(q), __int2half_rn(qzero))
    const int centered = q - qzero;
    return __int2half_rn(centered);
}
// Extracts a bit field from a single word: shifts q right by `shift` bits and
// keeps the bits selected by `mask`.
__forceinline__ __device__ int exb
(
    const uint32_t q,
    const int shift,
    const int mask
)
{
    const uint32_t aligned = q >> shift;
    return static_cast<int>(aligned & mask);
}
// Extracts a bit field that may straddle two words. q0 is passed as the low
// word and q1 as the high word to __funnelshift_rc, which shifts the pair
// right by `shift`; the bits selected by `mask` are then kept.
__forceinline__ __device__ int exb
(
    const uint32_t q1,
    const uint32_t q0,
    const int shift,
    const int mask
)
{
    const uint32_t aligned = __funnelshift_rc(q0, q1, shift);
    return static_cast<int>(aligned & mask);
}
#endif
server/exllamav2_kernels/exllamav2_kernels/hip/util.cuh
0 → 100644
View file @
12494cf5
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _util_cuh
#define _util_cuh
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/hip/HIPContext.h>
// Rounded-up integer division for positive operands: (x + size - 1) / size.
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
// printf-based debug helpers. Apart from DBGS, each macro prints the source
// text of its argument(s) followed by their values.
// DBGS: prints a C string on its own line.
#define DBGS(__x) printf("%s\n", __x)
// DBGI / DBGI2 / DBGI3: one, two or three values printed with %i.
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
// DBGX / DBGX2 / DBGX3: one, two or three values printed in hexadecimal (%x).
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
// DBGF / DBGF2 / DBGF3: one, two or three values printed with %f.
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
// DBGH / DBGH2 / DBGH3: half values, converted with __half2float before printing.
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
// DBGIH / DBGIH2: an integer (%i) followed by one or two half values.
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
// Returns ((qs + 1) / 16)^2 * max_scale, with every step after the integer
// increment carried out in half precision.
__forceinline__ __device__ half dq_scale_
(
    const int qs,
    const half max_scale
)
{
    const half sixteenth = __float2half_rn(1.0f / 16.0f);

    // (qs + 1) / 16
    const half ratio = __hmul(__int2half_rn(qs + 1), sixteenth);

    // Square the ratio, then apply the overall scale.
    const half squared = __hmul(ratio, ratio);
    return __hmul(squared, max_scale);
}
// Limits x to the interval [a, b]: x is first capped at b with fminf, then
// raised to at least a with fmaxf.
__forceinline__ __device__ float clamp
(
    float x,
    float a,
    float b
)
{
    const float capped = fminf(b, x);
    return fmaxf(a, capped);
}
// Forwards a runtime status code to gpu_assert together with the file and line
// of the call site.
// NOTE(review): this expands to a bare { ... } block rather than
// do { ... } while (0), so `if (c) cuda_check(x); else ...` does not compile.
#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
// Reports a failed runtime call. When `code` is not hipSuccess, writes the
// error string, file and line to stderr and, if `abort` is true (the default),
// terminates the process with `code` as the exit status. Does nothing when
// `code` is hipSuccess.
inline void gpu_assert
(
    hipError_t code,
    const char* file,
    int line,
    bool abort = true
)
{
    // Success: nothing to report.
    if (code == hipSuccess) return;

    fprintf(stderr, "CUDA error: %s %s %d\n", hipGetErrorString(code), file, line);

    if (abort) exit(code);
}
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably prints a rows x columns region of half values read
// through ptr, with `stride` elements between rows — confirm against the
// definition.
void
print_global_mem
(
const
half
*
ptr
,
int
rows
,
int
columns
,
int
stride
);
#endif
server/text_generation_server/layers/layernorm.py
View file @
12494cf5
...
...
@@ -72,7 +72,8 @@ if SYSTEM == "cuda":
return
normed_hidden_states
,
residual
elif
SYSTEM
==
"rocm"
:
from
vllm._C
import
ops
#from vllm._C import ops
from
vllm
import
_custom_ops
class
FastLayerNorm
(
nn
.
LayerNorm
):
def
forward
(
self
,
hidden_states
,
residual
=
None
):
...
...
server/text_generation_server/layers/linear.py
View file @
12494cf5
...
...
@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM
from
torch.nn
import
functional
as
F
import
os
if
SYSTEM
==
"rocm"
:
ROCM_USE_SKINNY_GEMM
=
os
.
getenv
(
"ROCM_USE_SKINNY_GEMM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
,
)
if
ROCM_USE_SKINNY_GEMM
:
try
:
from
vllm
import
_custom_C
except
Exception
as
e
:
raise
ImportError
(
f
"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error:
{
e
}
"
)
#
if SYSTEM == "rocm":
#
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
#
"true",
#
"1",
#
)
#
if ROCM_USE_SKINNY_GEMM:
#
try:
#
from vllm import _custom_C
#
except Exception as e:
#
raise ImportError(
#
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
#
)
class
FastLinear
(
torch
.
nn
.
Module
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment