Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
2221f4ce
Commit
2221f4ce
authored
Jul 10, 2023
by
Tim Dettmers
Browse files
Fixed potential memory leak.
parent
490153b2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
4 deletions
+9
-4
csrc/kernels.cu
csrc/kernels.cu
+9
-4
No files found.
csrc/kernels.cu
View file @
2221f4ce
...
@@ -3561,7 +3561,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3561,7 +3561,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
if
(
row_B
<
M
)
if
(
row_B
<
M
)
{
{
if
((
inner_idx_halved
+
num_values_8bit
)
<
K
)
if
((
inner_idx_halved
+
num_values_8bit
)
<
(
K
/
2
)
)
{
{
reinterpret_cast
<
int4
(
&
)[
num_values_8bit
]
>
(
local_B_4bit
)[
0
]
=
reinterpret_cast
<
int4
*>
(
B
)[(
offset_B
+
(
inner_idx_halved
))
/
(
num_values_8bit
)];
reinterpret_cast
<
int4
(
&
)[
num_values_8bit
]
>
(
local_B_4bit
)[
0
]
=
reinterpret_cast
<
int4
*>
(
B
)[(
offset_B
+
(
inner_idx_halved
))
/
(
num_values_8bit
)];
}
}
...
@@ -3569,15 +3569,21 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3569,15 +3569,21 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
(
num_values_8bit
);
j
++
)
for
(
int
j
=
0
;
j
<
(
num_values_8bit
);
j
++
)
if
((
inner_idx_halved
)
+
j
<
K
)
if
((
inner_idx_halved
)
+
j
<
(
K
/
2
)
)
local_B_4bit
[
j
]
=
B
[
offset_B
+
inner_idx_halved
+
j
];
local_B_4bit
[
j
]
=
B
[
offset_B
+
inner_idx_halved
+
j
];
else
else
local_B_4bit
[
j
]
=
0b01110111
;
local_B_4bit
[
j
]
=
0b01110111
;
}
}
}
}
else
{
#pragma unroll
for
(
int
j
=
0
;
j
<
(
num_values_8bit
);
j
++
)
local_B_4bit
[
j
]
=
0b01110111
;
}
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
num_values_
4
bit
;
k
++
)
for
(
int
k
=
0
;
k
<
num_values_
8
bit
;
k
++
)
{
{
#if __CUDA_ARCH__ >= 800
#if __CUDA_ARCH__ >= 800
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
...
@@ -3625,7 +3631,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
...
@@ -3625,7 +3631,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
local_C
+=
((
float
)
local_A
[
k
]
*
(
float
)
local_B
[
k
]);
local_C
+=
((
float
)
local_A
[
k
]
*
(
float
)
local_B
[
k
]);
#endif
#endif
}
}
}
}
local_C
=
WarpReduce
(
temp_storage
[
warp_idx
]).
Sum
(
local_C
);
local_C
=
WarpReduce
(
temp_storage
[
warp_idx
]).
Sum
(
local_C
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment