

# Download jax and jaxlib

```
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
```
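The wheel filenames show where each wheel can be installed. jaxlib is tagged `cp39` / `manylinux_2_31_x86_64`, meaning CPython 3.9 on x86-64 Linux with glibc 2.31 or newer. This is why the conda environment is created with `python=3.9`. The jax wheel is `py3-none-any` and is pure Python. Below is a minimal, stdlib-only sketch for checking a machine against those tags before installing. The helper names (`wheel_tags`, `python_tag_matches`, `glibc_at_least`) are hypothetical, and the tag rules are simplified. pip's own compatibility check is the authoritative one.

```python
import platform
import sys

# Filenames taken from the wget commands above.
WHEELS = [
    "jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl",
    "jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl",
]


def wheel_tags(filename):
    """Split a wheel filename into its (python, abi, platform) tags.

    Wheel names follow {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl,
    so the last three dash-separated fields are the compatibility tags.
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    python_tag, abi_tag, platform_tag = filename[: -len(".whl")].split("-")[-3:]
    return python_tag, abi_tag, platform_tag


def python_tag_matches(python_tag, version_info=sys.version_info):
    """Simplified check: 'cpXY' needs Python X.Y, 'pyX' accepts any Python X."""
    major, minor = version_info[0], version_info[1]
    accepted = {f"cp{major}{minor}", f"py{major}", f"py{major}{minor}"}
    # Compressed tag sets such as 'py2.py3' are dot-separated alternatives.
    return any(tag in accepted for tag in python_tag.split("."))


def glibc_at_least(platform_tag, libc=None):
    """For a manylinux_X_Y_<arch> tag, check that the running glibc is >= X.Y."""
    if not platform_tag.startswith("manylinux_"):
        return True  # 'any' and other tags: nothing to check here
    major, minor = platform_tag.split("_")[1:3]
    name, version = libc if libc is not None else platform.libc_ver()
    if name != "glibc" or not version:
        return False
    have = tuple(int(part) for part in version.split(".")[:2])
    return have >= (int(major), int(minor))


if __name__ == "__main__":
    for whl in WHEELS:
        py, abi, plat = wheel_tags(whl)
        ok = python_tag_matches(py) and glibc_at_least(plat)
        print(f"{whl.split('-')[0]}: python={py} abi={abi} platform={plat} -> {'OK' if ok else 'MISMATCH'}")
```

Run the script with the interpreter you plan to use (for example, inside the activated conda environment). A `MISMATCH` for jaxlib usually means the Python version or the system glibc is too old for this wheel.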


# Conda environment

```
# The jaxlib wheel is built for CPython 3.9 (cp39)
conda create -n pymc3 python=3.9
conda activate pymc3
# Install the jax/jaxlib wheels downloaded above
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp39-cp39-manylinux_2_31_x86_64.whl
pip install pymc==5.9.1
pip install numpyro==0.14.0
pip install seaborn==0.13.2
# scipy is installed from the Tsinghua PyPI mirror
pip install scipy==1.12.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
```
The resulting environment contains the following packages:


```
arviz               0.17.1
cachetools          5.4.0
cloudpickle         3.0.0
cons                0.4.6
contourpy           1.2.1
cycler              0.12.1
etuples             0.3.9
fastprogress        1.0.3
filelock            3.15.4
fonttools           4.53.1
h5netcdf            1.3.0
h5py                3.11.0
importlib_metadata  8.2.0
importlib_resources 6.4.2
jax                 0.4.23+das1.1.git387bd43.abi1.dtk2404
jaxlib              0.4.23+das1.1.git387bd43.abi1.dtk2404
kiwisolver          1.4.5
logical-unification 0.4.6
matplotlib          3.9.2
miniKanren          1.0.3
ml-dtypes           0.4.0
multipledispatch    1.0.0
numpy               1.26.4
numpyro             0.14.0
opt-einsum          3.3.0
packaging           24.1
pandas              2.2.2
pillow              10.4.0
pip                 24.2
pymc                5.9.1
pyparsing           3.1.2
pytensor            2.17.4
python-dateutil     2.9.0.post0
pytz                2024.1
scipy               1.12.0
seaborn             0.13.2
setuptools          72.1.0
six                 1.16.0
toolz               0.12.1
tqdm                4.66.5
typing_extensions   4.12.2
tzdata              2024.1
wheel               0.43.0
xarray              2024.7.0
xarray-einstats     0.7.0
zipp                3.20.0
```
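To check that an environment matches the table above, you can compare the installed versions with `importlib.metadata` from the standard library. The following is a minimal sketch. It assumes you paste the relevant rows of the table into `EXPECTED_TEXT`; only the key packages are copied here. It compares version strings exactly, including the `+das1.1...` local tag on jax/jaxlib. The function names are hypothetical.

```python
from importlib import metadata

# Key rows copied from the package table above ("name   version" per line).
EXPECTED_TEXT = """
jax                 0.4.23+das1.1.git387bd43.abi1.dtk2404
jaxlib              0.4.23+das1.1.git387bd43.abi1.dtk2404
numpy               1.26.4
numpyro             0.14.0
pymc                5.9.1
pytensor            2.17.4
scipy               1.12.0
seaborn             0.13.2
"""


def parse_pip_list(text):
    """Turn 'name   version' lines into a {name: version} dict, skipping other lines."""
    pins = {}
    for line in text.splitlines():
        fields = line.split()
        if len(fields) == 2:
            pins[fields[0]] = fields[1]
    return pins


def installed_version(name):
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def diff_versions(expected, lookup=installed_version):
    """Return {name: (expected, found)} for every package whose version differs."""
    mismatches = {}
    for name, want in expected.items():
        found = lookup(name)
        if found != want:
            mismatches[name] = (want, found)
    return mismatches


if __name__ == "__main__":
    problems = diff_versions(parse_pip_list(EXPECTED_TEXT))
    for name, (want, found) in sorted(problems.items()):
        print(f"{name}: expected {want}, found {found}")
    print("environment matches" if not problems else f"{len(problems)} mismatch(es)")
```

A `found None` entry means the package is missing. A plain `0.4.23` for jax or jaxlib suggests that the PyPI build replaced the DAS wheel.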



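# Test whether JAX runs on the GPU

```
import jax

# Anything other than "cpu" means JAX found an accelerator backend.
print(jax.default_backend())
# Lists the devices JAX can use; only CPU devices means jaxlib fell back to CPU.
print(jax.devices())
```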

# Test sampling time, GPU load, etc.

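```
import time

import numpy as np
import pymc as pm
import pytensor

# Use single precision; JAX also computes in 32-bit by default.
pytensor.config.floatX = "float32"

np.random.seed(123)
n = 10000
X = np.random.randn(n).astype("float32")
Y = (3 * X + np.random.randn(n)).astype("float32")

# Define the PyMC model: Y ~ Normal(alpha + beta * X, sigma)
with pm.Model() as model:
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=1)
    mu = alpha + beta * X
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

    # Sample with NumPyro's NUTS (runs through JAX) and time it.
    start = time.perf_counter()
    trace = pm.sample(1000, nuts_sampler="numpyro")
    print(f"Sampling took {time.perf_counter() - start:.1f} s")

# pm.sample returns an ArviZ InferenceData object by default, which pm.summary accepts.
print(pm.summary(trace))
```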
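The data are generated with alpha = 0, beta = 3 and noise sigma = 1. The priors are weak relative to n = 10000, so the posterior means in `pm.summary` should land close to the ordinary least-squares estimates. Below is a minimal, stdlib-only sketch of that reference fit. The helper `ols_fit` is hypothetical. The demo draws its own data from the same generative process with Python's `random`, so its numbers will not match the NumPy draws exactly. `ols_fit(X, Y)` also accepts the NumPy arrays from the model script directly.

```python
import math
import random


def ols_fit(xs, ys):
    """Least-squares fit of y = alpha + beta * x.

    Returns (alpha, beta, sigma), where sigma is the residual standard
    deviation with n - 2 degrees of freedom. Accepts lists or 1-D NumPy arrays.
    """
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    sxx = sum((x - x_mean) ** 2 for x in xs)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    beta = sxy / sxx
    alpha = y_mean - beta * x_mean
    rss = sum((y - alpha - beta * x) ** 2 for x, y in zip(xs, ys))
    sigma = math.sqrt(rss / (n - 2))
    return alpha, beta, sigma


# Same generative process as the PyMC example (alpha=0, beta=3, sigma=1),
# drawn with Python's `random` instead of NumPy.
random.seed(123)
n = 10000
xs = [random.gauss(0.0, 1.0) for _ in range(n)]
ys = [3 * x + random.gauss(0.0, 1.0) for x in xs]
alpha_hat, beta_hat, sigma_hat = ols_fit(xs, ys)
print(f"alpha={alpha_hat:.3f} beta={beta_hat:.3f} sigma={sigma_hat:.3f}")
```

If the `alpha`, `beta` and `sigma` means reported by `pm.summary` differ noticeably from these values, check the sampler diagnostics (`r_hat`, effective sample size) before trusting the timing results.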

   